jax.ops.index_update

jax.ops.index_update(x, idx, y, indices_are_sorted=False, unique_indices=False)

Pure equivalent of x[idx] = y.

Returns the value of x that would result from the NumPy-style indexed assignment:

x[idx] = y

Note that the index_update operator is pure: x itself is not modified. Instead, the new value that x would have taken is returned.
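The sketch below illustrates this purity. It assumes a recent JAX release, where jax.ops.index_update has been removed and the same functional update is written as x.at[idx].set(y). The helper set_first is hypothetical and only shows that a pure update can be compiled with jax.jit.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(4)

# Functional update: returns a new array and leaves x untouched.
new_x = x.at[1].set(5.0)

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    x[1] = 5.0
except TypeError:
    rejected = True
else:
    rejected = False

# Hypothetical helper: because the update is pure, it can be traced and
# compiled with jax.jit like any other array operation.
@jax.jit
def set_first(a, value):
    return a.at[0].set(value)

jitted = set_first(x, 2.0)
```

Here x still holds four zeros after both updates; only new_x and jitted contain the new values.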

Unlike NumPy’s x[idx] = y, if multiple indices refer to the same location, it is undefined which update is chosen; JAX may choose the order of updates arbitrarily and nondeterministically (e.g., due to concurrent updates on some hardware platforms).
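A minimal sketch of the duplicate-index case, assuming a recent JAX release where the functional form is x.at[idx].set(y). The only safe assumption about a duplicated location is that it receives one of the candidate values. If duplicates should be combined instead, the related .at[idx].add(y) accumulates every contribution.

```python
import jax.numpy as jnp

x = jnp.zeros(3)
idx = jnp.array([0, 0, 2])         # index 0 appears twice
vals = jnp.array([1.0, 2.0, 3.0])

out = x.at[idx].set(vals)
# out[0] may be 1.0 or 2.0: which duplicate "wins" is unspecified.
# out[1] is never written and stays 0.0; out[2] is written once and is 3.0.

# Accumulating instead of overwriting: both writes to index 0 are summed.
summed = x.at[idx].add(vals)
```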

Parameters
  • x – an array with the values to be updated.

  • idx – a NumPy-style index, consisting of None, integers, slice objects, ellipses, ndarrays with integer dtypes, or a tuple of the above. The jax.ops.index object provides convenient syntactic sugar for forming such indices.

  • y – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].

  • indices_are_sorted – whether idx is known to be sorted. If True, this may improve performance on some backends.

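  • unique_indices – whether idx is known to be free of duplicates. If True, this may improve performance on some backends; JAX does not check this promise, and the result is undefined if it does not hold.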
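The following sketch exercises the parameters above. It assumes a recent JAX release, where the index expression is written directly inside x.at[...] (playing the role of jax.ops.index[...]) and the hint flags are passed to .set(). It shows y broadcasting to the shape of x[idx], and an integer-array index that genuinely is sorted and duplicate-free, so the hints are safe to pass.

```python
import jax.numpy as jnp

x = jnp.ones((5, 6))

# Broadcasting: x[::2, 3:] has shape (3, 3), so a length-3 row of updates
# is broadcast across all three selected rows.
row = jnp.array([7.0, 8.0, 9.0])
out = x.at[::2, 3:].set(row)

# Integer-array index: [0, 2, 4] is sorted and free of duplicates.
# JAX does not verify these hints, so only pass them when they hold.
idx = jnp.array([0, 2, 4])
out2 = x.at[idx, 0].set(-1.0, indices_are_sorted=True, unique_indices=True)
```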

Returns

A new array with the same shape as x, containing the updated values.

>>> import jax
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_update(x, jax.ops.index[::2, 3:], 6.)
array([[1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 6., 6., 6.]], dtype=float32)
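For comparison, the same update can be sketched with the x.at[idx].set(y) syntax, assuming a recent JAX release where jax.ops.index_update is no longer available. The result is checked against NumPy's in-place assignment on a mutable array, which is exactly the x[idx] = y semantics this function mirrors.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.ones((5, 6))

# Functional update, equivalent to the jax.ops.index_update call above.
updated = x.at[::2, 3:].set(6.0)

# Reference result: NumPy's in-place indexed assignment on a mutable array.
expected = np.ones((5, 6), dtype=np.float32)
expected[::2, 3:] = 6.0

print(updated)
```

Rows 0, 2 and 4 receive 6.0 in columns 3 to 5, matching the table above, while x itself keeps all ones.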