jax.numpy.put
- jax.numpy.put(a, ind, v, mode=None, *, inplace=True)
Replaces specified elements of an array with given values.
LAX-backend implementation of numpy.put().
The semantics of numpy.put() are to modify arrays in-place, which JAX cannot do because JAX arrays are immutable. Thus jax.numpy.put() adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
Original docstring below.
The indexing works on the flattened target array. put is roughly equivalent to:
a.flat[ind] = v
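As a minimal sketch of the flattened indexing described above (the array contents are illustrative), the call below writes into a 2x3 array by flat position. Because JAX arrays are immutable, it returns an updated array rather than modifying `a`:

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3), dtype=jnp.int32)

# Flat positions 1 and 4 correspond to a[0, 1] and a[1, 1].
b = jnp.put(a, jnp.array([1, 4]), jnp.array([7, 8]), inplace=False)

print(b.tolist())  # [[0, 7, 0], [0, 8, 0]]
print(a.tolist())  # [[0, 0, 0], [0, 0, 0]]  (the input is left untouched)
```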
- Parameters:
a (ndarray) – Target array.
ind (array_like) – Target indices, interpreted as integers.
v (array_like) – Values to place in a at target indices. If v is shorter than ind it will be repeated as necessary.
mode ({'raise', 'wrap', 'clip'}, optional) –
Specifies how out-of-bounds indices will behave.
'raise' – raise an error (default)
'wrap' – wrap around
'clip' – clip to the range
'clip' mode means that all indices that are too large are replaced by the index that addresses the last element along that axis. Note that this disables indexing with negative numbers. In 'raise' mode, if an exception occurs the target array may still be modified.
inplace (bool, default=True) – If left to its default value of True, JAX will raise an error. This is because the semantics of numpy.put() are to modify the array in-place, which is not possible in JAX due to the immutability of JAX arrays.
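The parameter descriptions above can be illustrated with a short sketch. The values are chosen only for illustration. It shows an out-of-bounds index handled by 'wrap' and 'clip', and a `v` that is shorter than `ind` being repeated:

```python
import jax.numpy as jnp

a = jnp.zeros(5, dtype=jnp.int32)

# Index 7 is out of bounds for a length-5 array.
wrapped = jnp.put(a, jnp.array([7]), jnp.array([1]), mode="wrap", inplace=False)
clipped = jnp.put(a, jnp.array([7]), jnp.array([1]), mode="clip", inplace=False)

# v supplies two values for four indices, so it is repeated as necessary.
repeated = jnp.put(a, jnp.array([0, 1, 2, 3]), jnp.array([1, 2]), inplace=False)

print(wrapped.tolist())   # [0, 0, 1, 0, 0]  (7 wraps around to position 2)
print(clipped.tolist())   # [0, 0, 0, 0, 1]  (7 is clipped to the last position)
print(repeated.tolist())  # [1, 2, 1, 2, 0]
```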
- Return type:
Array
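A small sketch of the calling convention, assuming a recent JAX release in which `jax.Array` is the public array type. The exact exception class raised for the default `inplace=True` is not stated above, so the example catches a generic `Exception`:

```python
import jax
import jax.numpy as jnp

a = jnp.arange(4)

try:
    jnp.put(a, jnp.array([0]), jnp.array([9]))  # inplace defaults to True
    raised = False
except Exception:  # exact exception class is not assumed here
    raised = True

# Passing inplace=False returns the updated array as a new value.
out = jnp.put(a, jnp.array([0]), jnp.array([9]), inplace=False)

print(raised)                      # True
print(isinstance(out, jax.Array))  # True
print(out.tolist())                # [9, 1, 2, 3]
```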