jax.ops package

Indexed update operators

JAX is intended to be used with a functional style of programming, and does not support NumPy-style indexed assignment directly. Instead, JAX provides alternative pure functional operators for indexed updates to arrays.

JAX array types have a property at, which can be used as follows (where idx is a NumPy index expression).

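Alternate syntax                Equivalent in-place expression
x = x.at[idx].set(y)            x[idx] = y
x = x.at[idx].add(y)            x[idx] += y
x = x.at[idx].multiply(y)       x[idx] *= y
x = x.at[idx].divide(y)         x[idx] /= y
x = x.at[idx].power(y)          x[idx] **= y
x = x.at[idx].min(y)            x[idx] = np.minimum(x[idx], y)
x = x.at[idx].max(y)            x[idx] = np.maximum(x[idx], y)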
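A minimal sketch of the x.at[idx] operators, applied one after another to a single element. It assumes a recent JAX release in which x.at exposes set, add, multiply, divide, power, min and max; the intermediate names a through g and s are illustrative only. Each call returns a new array, so the chain never touches x itself.

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = 1  # any NumPy index expression works: an int, a slice, an integer array, ...

a = x.at[idx].set(2.0)        # x[idx] = 2.0                      -> a[1] == 2
b = a.at[idx].add(3.0)        # x[idx] += 3.0                     -> b[1] == 5
c = b.at[idx].multiply(2.0)   # x[idx] *= 2.0                     -> c[1] == 10
d = c.at[idx].divide(5.0)     # x[idx] /= 5.0                     -> d[1] == 2
e = d.at[idx].power(3.0)      # x[idx] **= 3.0                    -> e[1] == 8
f = e.at[idx].min(4.0)        # x[idx] = np.minimum(x[idx], 4.0)  -> f[1] == 4
g = f.at[idx].max(6.0)        # x[idx] = np.maximum(x[idx], 6.0)  -> g[1] == 6

# A slice index updates several elements at once.
s = x.at[2:4].set(jnp.array([7.0, 9.0]))
```

Only the indexed element changes at each step; the other entries of every result stay at zero.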

None of these expressions modifies the original x; instead, each returns a modified copy of x. However, inside a jit()-compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
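The sketch below illustrates both halves of that statement: outside jit the original array is left unchanged, and inside jit the update is written in the x = x.at[idx]... rebinding form. The function name bump_first is hypothetical, and the in-place behavior under jit is an XLA buffer detail that this snippet does not observe directly; it only checks the returned values.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4)
y = x.at[0].set(10)  # returns a new array; x itself is not modified


@jax.jit
def bump_first(v):
    # Hypothetical helper: rebinding v to the result of .at[...] is the
    # pattern that the text above describes as applied in place under jit.
    v = v.at[0].add(1)
    return v


z = bump_first(jnp.arange(4))
```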

By default, JAX assumes that all indices are in-bounds. There is experimental support for giving more precise semantics to out-of-bounds indexed accesses via the mode parameter of functions such as get and set. Valid values for mode include "clip", which clamps out-of-bounds indices into range, and "fill" or "drop", which are aliases: out-of-bounds reads are filled with a scalar fill_value, and out-of-bounds writes are discarded.
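A short sketch of the three modes, assuming an integer-array index that mixes an in-bounds position (1) with an out-of-bounds one (10) on a length-5 array.

```python
import jax.numpy as jnp

x = jnp.arange(5)            # [0, 1, 2, 3, 4]
idx = jnp.array([1, 10])     # index 10 is out of bounds

# "clip": index 10 is clamped to the last valid index, 4.
clipped = x.at[idx].get(mode="clip")

# "fill": the out-of-bounds read returns fill_value instead.
filled = x.at[idx].get(mode="fill", fill_value=-1)

# "drop" (alias of "fill"): the out-of-bounds write is discarded,
# while the in-bounds write to position 1 still happens.
dropped = x.at[idx].set(99, mode="drop")
```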

Indexed update functions (deprecated)

The following functions are deprecated aliases for the x.at[idx] operators (for example, index_add(x, idx, y) corresponds to x.at[idx].add(y)). Use the x.at[idx] operators instead.

index
Helper object for building indexes for indexed update functions.

index_update(x, idx, y[, …])

Pure equivalent of x[idx] = y.

index_add(x, idx, y[, indices_are_sorted, …])

Pure equivalent of x[idx] += y.

index_mul(x, idx, y[, indices_are_sorted, …])

Pure equivalent of x[idx] *= y.

index_min(x, idx, y[, indices_are_sorted, …])

Pure equivalent of x[idx] = minimum(x[idx], y).

index_max(x, idx, y[, indices_are_sorted, …])

Pure equivalent of x[idx] = maximum(x[idx], y).
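Because these deprecated functions may not be available in newer JAX releases, the migration sketch below uses only the x.at[idx] forms and records the corresponding old call in a comment. It assumes np.index_exp as a stand-in for the index helper, since both produce an ordinary NumPy index expression; the array values are arbitrary.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.full((3, 3), 2.0)
y = 3.0
# Assumption: np.index_exp plays the role of the old index helper here.
idx = np.index_exp[0, :]  # first row, i.e. (0, slice(None))

updated = x.at[idx].set(y)           # was: index_update(x, idx, y) -> row 0 == 3
added = x.at[idx].add(y)             # was: index_add(x, idx, y)    -> row 0 == 5
multiplied = x.at[idx].multiply(y)   # was: index_mul(x, idx, y)    -> row 0 == 6
minimized = x.at[idx].min(y)         # was: index_min(x, idx, y)    -> row 0 == 2
maximized = x.at[idx].max(y)         # was: index_max(x, idx, y)    -> row 0 == 3
```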

Other operators

segment_max(data, segment_ids[, …])

Computes the maximum within segments of an array.

segment_min(data, segment_ids[, …])

Computes the minimum within segments of an array.

segment_prod(data, segment_ids[, …])

Computes the product within segments of an array.

segment_sum(data, segment_ids[, …])

Computes the sum within segments of an array.
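A small sketch of the four segment reductions on a toy array; the data and segment ids are made up for illustration. num_segments is passed explicitly because it fixes the output size: when omitted, it is inferred as max(segment_ids) + 1, which needs concrete values and therefore cannot be used inside a jit-compiled function.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 1, 2])  # segment that each element of data belongs to

sums = ops.segment_sum(data, segment_ids, num_segments=3)    # [3., 7., 5.]
prods = ops.segment_prod(data, segment_ids, num_segments=3)  # [2., 12., 5.]
maxes = ops.segment_max(data, segment_ids, num_segments=3)   # [2., 4., 5.]
mins = ops.segment_min(data, segment_ids, num_segments=3)    # [1., 3., 5.]
```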