jax.ops module


The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
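As a rough migration sketch: the removed helpers took the form `jax.ops.index_update(x, idx, y)` and `jax.ops.index_add(x, idx, y)`. With `.at`, the index goes inside the brackets and the operation is a method call. The arrays below are illustrative only. Like the old helpers, `.at` returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Replaces the removed jax.ops.index_update: set element 1 to 10.0
y = x.at[1].set(10.0)

# Replaces the removed jax.ops.index_add: add 1.0 to elements 2 and 3
z = x.at[2:4].add(1.0)

# JAX arrays are immutable, so x itself is still all zeros
print(x, y, z)
```

Other update methods on `.at` include `multiply`, `min`, `max`, and `get`.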

Segment reduction operators

segment_max(data, segment_ids[, ...])

Computes the maximum within segments of an array.
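A minimal sketch of `segment_max`. The data and ids are made up for illustration. Element `i` of `data` belongs to segment `segment_ids[i]`, and passing `num_segments=4` fixes the output length. This assumes the behavior of current JAX releases: a segment that receives no elements (id 2 here) is filled with the identity of the maximum, which is `-inf` for floating-point data.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 5.0, 2.0, 7.0, 3.0])
segment_ids = jnp.array([0, 0, 1, 1, 3])  # no element is assigned to segment 2

# One output slot per segment id 0..3
maxes = ops.segment_max(data, segment_ids, num_segments=4)
print(maxes)  # segments 0, 1, 3 hold 5.0, 7.0, 3.0; empty segment 2 holds -inf
```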

segment_min(data, segment_ids[, ...])

Computes the minimum within segments of an array.
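The sketch below shows `segment_min` on 2-D data inside `jax.jit`; the function name `rowwise_segment_min` is hypothetical. The reduction runs over the leading axis, so each output row is the column-wise minimum of the rows in that segment. Under `jit`, `num_segments` must be a static Python int, because it fixes the output shape. The default, `max(segment_ids) + 1`, cannot be computed from traced ids.

```python
import jax
import jax.numpy as jnp
from jax import ops

@jax.jit
def rowwise_segment_min(data, segment_ids):
    # Static num_segments keeps the output shape known at trace time
    return ops.segment_min(data, segment_ids, num_segments=2)

data = jnp.array([[4, 9],
                  [2, 8],
                  [6, 1]])
segment_ids = jnp.array([0, 0, 1])  # rows 0 and 1 -> segment 0, row 2 -> segment 1

mins = rowwise_segment_min(data, segment_ids)
print(mins)  # shape (2, 2): (num_segments,) + data.shape[1:]
```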

segment_prod(data, segment_ids[, ...])

Computes the product within segments of an array.
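A short illustration of `segment_prod` with illustrative values. The segment ids do not need to be sorted. Sorted ids only permit the `indices_are_sorted=True` hint. This assumes, as in current JAX releases, that an empty segment (id 2 below) is filled with the multiplicative identity `1`.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([1, 0, 1, 0])  # unsorted ids are fine

# Segment 0: 3 * 5, segment 1: 2 * 4, segment 2: no elements
prods = ops.segment_prod(data, segment_ids, num_segments=3)
print(prods)
```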

segment_sum(data, segment_ids[, ...])

Computes the sum within segments of an array.
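A sketch of `segment_sum` with made-up data, plus one common pattern built on it: a per-segment mean computed as sums divided by counts. The final line suggests that, for in-range ids, `segment_sum` agrees with a scatter-add written using the `.at` property. The `jnp.maximum(counts, 1)` guard is an illustrative choice that avoids dividing by zero for empty segments.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 2, 2])
num_segments = 3

sums = ops.segment_sum(data, segment_ids, num_segments=num_segments)

# Count elements per segment by summing a vector of ones
counts = ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments=num_segments)
means = sums / jnp.maximum(counts, 1)  # guard against empty segments

# Equivalent scatter-add using the .at property
same = jnp.zeros(num_segments).at[segment_ids].add(data)
print(sums, means, same)
```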