jax.ops package

Indexed update operators

JAX is designed for a functional style of programming. JAX arrays are immutable, so NumPy-style in-place indexed assignment such as x[idx] = y is not supported. Instead, JAX provides pure alternatives, namely jax.ops.index_update() and its relatives, which return an updated copy of the array rather than modifying it.
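A minimal sketch of the difference, assuming a recent JAX install: direct assignment raises a TypeError, while the pure form returns a new array. It uses the x.at[idx].set(y) spelling (the alternate syntax covered later on this page) rather than jax.ops.index_update(), because the jax.ops.index_* functions may not be available in newer JAX releases. Variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# In-place assignment is rejected: JAX arrays are immutable.
try:
    x[0] = 1.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

# The pure alternative builds a new array; x itself is not modified.
y = x.at[0].set(1.0)  # y holds [1., 0., 0., 0., 0.]
```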

index

Helper object for building indexes for indexed update functions.

index_update(x, idx, y[, …])

Pure equivalent of x[idx] = y.

index_add(x, idx, y[, indices_are_sorted, …])

Pure equivalent of x[idx] += y.

index_mul(x, idx, y[, indices_are_sorted, …])

Pure equivalent of x[idx] *= y.

index_min(x, idx, y[, indices_are_sorted, …])

Pure equivalent of x[idx] = minimum(x[idx], y).

index_max(x, idx, y[, indices_are_sorted, …])

Pure equivalent of x[idx] = maximum(x[idx], y).
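The five operators above can be compared side by side with a small sketch. It uses the equivalent .at spelling described in the next subsection, with arbitrary illustrative values; each comment gives the resulting array. The last lines show one behaviour worth knowing: when an index repeats, add applies every update, whereas NumPy's buffered x[idx] += y applies it only once.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([5.0, 1.0, 4.0, 2.0])
idx = jnp.array([0, 2])
y = jnp.array([3.0, 6.0])

set_result = x.at[idx].set(y)  # x[idx] = y                    -> [3, 1, 6, 2]
add_result = x.at[idx].add(y)  # x[idx] += y                   -> [8, 1, 10, 2]
mul_result = x.at[idx].mul(y)  # x[idx] *= y                   -> [15, 1, 24, 2]
min_result = x.at[idx].min(y)  # x[idx] = minimum(x[idx], y)   -> [3, 1, 4, 2]
max_result = x.at[idx].max(y)  # x[idx] = maximum(x[idx], y)   -> [5, 1, 6, 2]

# Repeated indices: every update is accumulated.
dup = jnp.array([1, 1, 1])
jax_counts = jnp.zeros(3).at[dup].add(1.0)  # -> [0, 3, 0]

# NumPy's buffered in-place add applies the repeated update only once.
np_counts = np.zeros(3)
np_counts[[1, 1, 1]] += 1.0  # -> [0, 1, 0]
```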

Syntactic sugar for indexed update operators

JAX also provides an alternative syntax for these indexed update operators. Specifically, JAX ndarray types have an at property, which can be used as follows (where idx can be any index expression).

Alternate syntax      Equivalent expression
x.at[idx].set(y)      jax.ops.index_update(x, jax.ops.index[idx], y)
x.at[idx].add(y)      jax.ops.index_add(x, jax.ops.index[idx], y)
x.at[idx].mul(y)      jax.ops.index_mul(x, jax.ops.index[idx], y)
x.at[idx].min(y)      jax.ops.index_min(x, jax.ops.index[idx], y)
x.at[idx].max(y)      jax.ops.index_max(x, jax.ops.index[idx], y)

Note that none of these expressions modify the original x; instead they return a modified copy of x.
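The sketch below, with illustrative values, checks the two points made above: idx inside .at[...] may be any index expression (here a pair of slices, then a pair of integer arrays), and the original array is left unchanged. It also shows the same syntax inside a jax.jit-compiled function; bump_diagonal is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)

# Slices work as the index: every other column of rows 1 and 2.
y = x.at[1:, ::2].set(-1.0)

@jax.jit
def bump_diagonal(a):
    # Integer-array indexing picks the (i, i) entries; the input is not mutated.
    i = jnp.arange(min(a.shape))
    return a.at[i, i].add(100.0)

z = bump_diagonal(x)
# x still holds its original values 0..11.
```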

Other operators

segment_sum(data, segment_ids[, …])

Computes the sum within segments of an array.
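A small sketch of segment_sum with illustrative data: each element of data is added into the output slot named by the matching entry of segment_ids. Passing num_segments explicitly fixes the output length. The last line rebuilds the same result with the add operator from above, which is one way to read what a segment sum computes.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
segment_ids = jnp.array([0, 0, 1, 1, 2])

# Segment 0 gets 1+2, segment 1 gets 3+4, segment 2 gets 5.
sums = ops.segment_sum(data, segment_ids, num_segments=3)  # -> [3, 7, 5]

# Same result as scattering data into zeros with accumulating add.
via_add = jnp.zeros(3).at[segment_ids].add(data)
```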