jax.ops package

Indexed update operators

JAX is intended to be used with a functional style of programming, and hence does not support NumPy-style indexed assignment directly. Instead, JAX provides pure alternatives, namely jax.ops.index_update() and its relatives. These return an updated copy of the array rather than modifying the original in place.
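To show what "pure" means here, the following is a minimal NumPy sketch, not JAX's implementation. The helper `pure_index_update` is hypothetical. It leaves its input untouched and returns a modified copy, which is the behaviour attributed to `jax.ops.index_update(x, idx, y)`.

```python
import numpy as np


def pure_index_update(x, idx, y):
    """Hypothetical NumPy stand-in for a pure ``x[idx] = y``.

    Instead of mutating ``x``, copy it, apply the assignment to the copy,
    and return the copy. The caller's array is never modified.
    """
    out = np.array(x, copy=True)  # fresh buffer, independent of x
    out[idx] = y                  # ordinary NumPy assignment, on the copy only
    return out


x = np.zeros(5)
z = pure_index_update(x, slice(1, 3), 7.0)
print(x)  # [0. 0. 0. 0. 0.]  (original unchanged)
print(z)  # [0. 7. 7. 0. 0.]
```

In JAX itself, the corresponding call on this page's API would be written as `jax.ops.index_update(x, jax.ops.index[1:3], 7.0)`.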

index                      Helper object for building indexes for indexed update functions.
index_update(x, idx, y)    Pure equivalent of x[idx] = y.
index_add(x, idx, y)       Pure equivalent of x[idx] += y.
index_min(x, idx, y)       Pure equivalent of x[idx] = minimum(x[idx], y).
index_max(x, idx, y)       Pure equivalent of x[idx] = maximum(x[idx], y).
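The operators listed above differ only in how they combine the new values with the old ones. The sketch below approximates their semantics with NumPy's unbuffered `ufunc.at` methods (`np.add.at`, `np.minimum.at`, `np.maximum.at`). All helper names in it are hypothetical. It uses `ufunc.at` rather than `out[idx] += y` because plain NumPy in-place addition applies only one update when an index repeats. This sketch assumes repeated indices should accumulate, as a scatter-add does. The `_IndexHelper` class is likewise a hypothetical imitation of the `index` helper object: indexing it simply returns the index expression, much like `np.s_`.

```python
import numpy as np


class _IndexHelper:
    """Hypothetical imitation of the ``index`` helper: ``helper[expr]`` returns ``expr``."""

    def __getitem__(self, idx):
        return idx


index = _IndexHelper()


def _pure_ufunc_at(ufunc, x, idx, y):
    """Apply ``ufunc.at`` to a copy of ``x`` so the input is left unchanged."""
    out = np.array(x, copy=True)
    ufunc.at(out, idx, y)  # unbuffered: every occurrence of a repeated index contributes
    return out


def pure_index_add(x, idx, y):
    # Sketch of x[idx] += y, accumulating over repeated indices.
    return _pure_ufunc_at(np.add, x, idx, y)


def pure_index_min(x, idx, y):
    # Sketch of x[idx] = minimum(x[idx], y).
    return _pure_ufunc_at(np.minimum, x, idx, y)


def pure_index_max(x, idx, y):
    # Sketch of x[idx] = maximum(x[idx], y).
    return _pure_ufunc_at(np.maximum, x, idx, y)


x = np.array([5.0, 5.0, 5.0, 5.0])
print(pure_index_add(x, index[[0, 0, 2]], 1.0))  # [7. 5. 6. 5.]  (index 0 hit twice)
print(pure_index_min(x, index[1:3], 3.0))        # [5. 3. 3. 5.]
print(pure_index_max(x, index[1:3], 3.0))        # [5. 5. 5. 5.]
print(x)                                         # [5. 5. 5. 5.]  (unchanged)
```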

Other operators

segment_sum(data, segment_ids[, num_segments])    Computes the sum within segments of an array.
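As a reference for what segment_sum computes, here is a small NumPy sketch. It is not the JAX implementation, and the name `segment_sum_reference` is hypothetical. Each element `data[i]` is added into output row `segment_ids[i]`, and segments that receive no elements are zero. The sketch makes two assumptions. When `num_segments` is omitted, it is taken as `max(segment_ids) + 1`. Ids outside `[0, num_segments)` are rejected with an error. JAX's own handling of out-of-range ids may differ.

```python
import numpy as np


def segment_sum_reference(data, segment_ids, num_segments=None):
    """Hypothetical NumPy reference for a segment sum along axis 0.

    out[k] = sum of data[i] over all i with segment_ids[i] == k.
    """
    data = np.asarray(data)
    segment_ids = np.asarray(segment_ids)
    if num_segments is None:
        # Assumption: infer the number of segments from the largest id.
        num_segments = int(segment_ids.max()) + 1
    if segment_ids.size and (segment_ids.min() < 0 or segment_ids.max() >= num_segments):
        # Negative ids would silently wrap in NumPy indexing, so reject them here.
        raise ValueError("segment id out of range for this sketch")
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)  # accumulate each element into its segment
    return out


data = np.array([1, 2, 3, 4, 5])
ids = np.array([0, 0, 1, 3, 3])
print(segment_sum_reference(data, ids))     # [3 3 0 9]
print(segment_sum_reference(data, ids, 5))  # [3 3 0 9 0]
```

Using `np.add.at` keeps the accumulation correct when several elements share a segment id, which a plain `out[segment_ids] += data` would not.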