jax.ops.index_add

jax.ops.index_add(x, idx, y, indices_are_sorted=False, unique_indices=False)

Pure equivalent of x[idx] += y.

Returns the value of x that would result from the NumPy-style indexed assignment:

x[idx] += y

Note that the index_add operator is pure: x itself is not modified. Instead, the new value that x would have taken is returned.
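
For intuition, the pure semantics can be mimicked in plain NumPy. The sketch below uses a hypothetical helper, index_add_reference (not part of JAX), that copies x and applies numpy.add.at to the copy, so the input stays untouched and repeated indices accumulate. It ignores the indices_are_sorted and unique_indices hints and accepts only the index forms that numpy.add.at supports (integers, integer arrays, slices, and tuples of these).

```python
import numpy as np

def index_add_reference(x, idx, y):
    """Hypothetical NumPy reference for the semantics of index_add.

    Returns a new array; `x` itself is never modified.
    """
    out = np.array(x, copy=True)  # work on a copy so the caller's array is untouched
    np.add.at(out, idx, y)        # unbuffered in-place add on the copy; duplicates accumulate
    return out

x = np.ones((5, 6), dtype=np.float32)
# Equivalent region to jax.ops.index[2:4, 3:]
z = index_add_reference(x, (slice(2, 4), slice(3, None)), 6.0)

print(x.sum())     # 30.0: the input still holds only ones
print(z[2:4, 3:])  # every entry in the updated block is 7.0
```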

Unlike the NumPy code x[idx] += y, if multiple indices refer to the same location, the updates are summed. (NumPy would apply only the last update rather than summing them.) The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
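
The difference is easiest to see with a repeated index. This NumPy-only sketch contrasts plain fancy-index assignment, where one write to the repeated location overwrites the other, with numpy.add.at, which accumulates the updates the way index_add does.

```python
import numpy as np

idx = np.array([0, 0, 2])       # location 0 appears twice
y = np.array([1.0, 2.0, 5.0])

# Plain NumPy: a[idx] += y is a[idx] = a[idx] + y, so the two writes to
# location 0 overwrite each other and (in practice) the last one wins.
a = np.zeros(3)
a[idx] += y

# Accumulating semantics, matching index_add: both updates land on location 0.
b = np.zeros(3)
np.add.at(b, idx, y)

print(a)  # [2. 0. 5.]
print(b)  # [3. 0. 5.]
```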

Parameters
  • x – an array with the values to be updated.

  • idx – a NumPy-style index, consisting of None, integers, slice objects, ellipses, ndarrays with integer dtypes, or a tuple of the above. The jax.ops.index object provides convenient syntactic sugar for forming indices.

  • y – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].

  • indices_are_sorted – whether idx is known to be sorted. Setting this to True may improve performance on some backends.

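  • unique_indices – whether idx is known to be free of duplicates. This is a promise that is not checked; if it is True but idx does contain duplicates, the result is undefined.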
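
The broadcasting requirement on y can be checked ahead of time with NumPy. The helper updates_fit below is hypothetical (not part of JAX): it computes the shape that x[idx] would have and asks numpy.broadcast_to whether y can be stretched to that shape.

```python
import numpy as np

def updates_fit(x, idx, y):
    """Hypothetical check: can y be broadcast to the shape of x[idx]?"""
    target = x[idx].shape
    try:
        np.broadcast_to(y, target)  # raises ValueError if y is incompatible
        return True
    except ValueError:
        return False

x = np.zeros((5, 6))
idx = (slice(2, 4), slice(3, None))  # same region as jax.ops.index[2:4, 3:]

print(x[idx].shape)                         # (2, 3)
print(updates_fit(x, idx, 6.0))             # True: a scalar broadcasts to anything
print(updates_fit(x, idx, np.arange(3.0)))  # True: shape (3,) broadcasts to (2, 3)
print(updates_fit(x, idx, np.arange(2.0)))  # False: shape (2,) does not
```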

Returns

An array with the same shape and dtype as x, holding the updated values.

>>> import jax
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_add(x, jax.ops.index[2:4, 3:], 6.)
array([[1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 7., 7., 7.],
       [1., 1., 1., 7., 7., 7.],
       [1., 1., 1., 1., 1., 1.]], dtype=float32)
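
Later JAX releases deprecated and then removed jax.ops.index_add in favor of the indexed-update syntax x.at[idx].add(y). The sketch below assumes a JAX version that provides the .at property; it repeats the example above, shows duplicate indices accumulating, and passes the sortedness and uniqueness hints as keyword arguments.

```python
import jax.numpy as jnp

x = jnp.ones((5, 6))
# Same update as jax.ops.index_add(x, jax.ops.index[2:4, 3:], 6.)
z = x.at[2:4, 3:].add(6.)
# x is unchanged; z is a new array with 7.0 in the updated block.

# Duplicate indices accumulate rather than overwrite.
v = jnp.zeros(3).at[jnp.array([0, 0, 2])].add(jnp.array([1., 2., 5.]))
print(v)  # [3. 0. 5.]

# Hints: only pass True when the indices really are sorted / duplicate-free.
w = jnp.zeros(4).at[jnp.array([0, 2])].add(
    1., indices_are_sorted=True, unique_indices=True)
```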