jax.ops.index_update
jax.ops.index_update(x, idx, y, indices_are_sorted=False, unique_indices=False)

Pure equivalent of x[idx] = y.

Returns the value of x that would result from the NumPy-style indexed assignment:

    x[idx] = y
Note that index_update is pure: x itself is not modified. Instead, the new value that x would have taken is returned.
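A minimal sketch of the purity guarantee. It assumes a JAX release in which jax.ops.index_update has been superseded by the x.at[idx].set(y) indexed-update syntax, which performs the same pure update. The names new_x, set_first and compiled are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(5)

# Functional update: .at[2].set(...) returns a new array and leaves x untouched.
new_x = x.at[2].set(10.0)

# JAX arrays are immutable, so NumPy-style in-place assignment raises TypeError.
mutation_failed = False
try:
    x[2] = 10.0
except TypeError:
    mutation_failed = True

# Because the update is pure, it can also be traced and compiled with jax.jit.
@jax.jit
def set_first(a, v):
    return a.at[0].set(v)

compiled = set_first(x, 7.0)
```

After both updates, x still holds all ones. Only new_x and compiled carry the changes.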
Unlike NumPy’s x[idx] = y, if multiple indices refer to the same location, it is undefined which update is chosen; JAX may choose the order of updates arbitrarily and nondeterministically (e.g., due to concurrent updates on some hardware platforms).

Parameters
x – an array with the values to be updated.
idx – a NumPy-style index, consisting of None, integers, slice objects, ellipses, ndarrays with integer dtypes, or a tuple of the above. The jax.ops.index object provides convenient syntactic sugar for forming such indices.
y – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].
indices_are_sorted – whether idx is known to be sorted.
unique_indices – whether idx is known to be free of duplicates.
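The sketch below exercises the idx and y parameters through the .at[...] syntax, which is assumed to be the equivalent of index_update in newer JAX releases. It shows three cases. The first is an integer-array index with a broadcast y and truthful indices_are_sorted/unique_indices hints. The second is a tuple of slices built with NumPy's np.index_exp helper, standing in for jax.ops.index. The third is a duplicate index whose winning value is deliberately left unasserted. Variable names are illustrative.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.zeros((4, 3))

# Integer-array index: rows 0 and 2. x[rows] has shape (2, 3), and the
# update of shape (3,) is broadcast against it.
rows = jnp.array([0, 2])
y = jnp.array([1.0, 2.0, 3.0])
# rows is sorted and duplicate-free, so the optional hints are truthful here.
by_rows = x.at[rows].set(y, indices_are_sorted=True, unique_indices=True)

# Tuple of slices (rows 1 onward, every other column), built with NumPy's
# index_exp helper in place of jax.ops.index.
idx = np.index_exp[1:, ::2]
by_slices = x.at[idx].set(5.0)

# Duplicate indices: both updates target row 1. Which value lands there is
# unspecified, so leave unique_indices at its default of False.
dup_rows = jnp.array([1, 1])
dup_vals = jnp.array([[7.0, 7.0, 7.0], [8.0, 8.0, 8.0]])
dup = x.at[dup_rows].set(dup_vals)
```

The hints are promises to the compiler. Passing True for an index that is not actually sorted or duplicate-free should be assumed to give unreliable results.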
Returns
An array with the same shape and dtype as x, in which the elements selected by idx are replaced by y.
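A small check of the returned array's shape and dtype, again written with the .at[...].set(...) syntax on the assumption that it matches index_update. The update values here are int32. The result is expected to keep x's float32 dtype and (2, 3) shape.

```python
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.float32)

# int32 update values are cast to x's dtype; the result keeps x's shape.
out = x.at[0].set(jnp.array([4, 5, 6], dtype=jnp.int32))
```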
>>> import jax
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_update(x, jax.ops.index[::2, 3:], 6.)
array([[1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 6., 6., 6.]], dtype=float32)
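For JAX releases where jax.ops.index_update is no longer available, the same update can be sketched with the .at[...].set(...) syntax. This is assumed to be equivalent and is not the original function. The comparison against NumPy's in-place x[idx] = y on a separate NumPy array makes the "pure equivalent" relationship concrete.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.ones((5, 6))
# Same update as the example above: every other row, columns 3 onward, set to 6.
out = x.at[::2, 3:].set(6.0)

# NumPy's mutating assignment on a separate array, for comparison.
expected = np.ones((5, 6), dtype=np.float32)
expected[::2, 3:] = 6.0
```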