jax.numpy.put
jax.numpy.put(*args, **kwargs)
Replaces specified elements of an array with given values.
LAX-backend implementation of numpy.put().
The NumPy function numpy.put() is not available in JAX and will raise a NotImplementedError, because np.put modifies its arguments in place, whereas JAX arrays are immutable. A JAX-compatible approach to array updates can be found in jax.numpy.ndarray.at.
Original docstring below.
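As a rough sketch of the jax.numpy.ndarray.at approach, the effect of np.put can be reproduced functionally by flattening the array, scattering the new values with .at[...].set(), and reshaping back. The helper name put_functional is hypothetical and not part of JAX. This sketch assumes v broadcasts against ind (np.put instead repeats a short v as needed) and that all indices are in range; it does not reproduce np.put's mode handling of out-of-range indices.

```python
import jax.numpy as jnp


def put_functional(a, ind, v):
    """Hypothetical helper: return a copy of `a` with a.flat[ind] = v.

    Unlike numpy.put, `a` is left unchanged and a new array is returned.
    Assumes `v` broadcasts against `ind` and all indices are in range.
    """
    flat = jnp.ravel(a)                      # work on the flattened target, as np.put does
    flat = flat.at[jnp.asarray(ind)].set(v)  # functional scatter: returns a new array
    return flat.reshape(jnp.shape(a))        # restore the original shape


a = jnp.arange(6).reshape(2, 3)
b = put_functional(a, [0, 4], jnp.array([-1, -2]))
# b holds the updated values; a is unchanged because JAX arrays are immutable.
```

The index list is converted with jnp.asarray because JAX rejects plain Python lists as indices.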
The indexing works on the flattened target array. put is roughly equivalent to:
a.flat[ind] = v
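To illustrate what "works on the flattened target array" means, the NumPy (not JAX) example below applies np.put to a 2-D array and compares it with the a.flat form. The index and value choices are arbitrary; flat positions 1 and 5 of a 2x3 array correspond to positions (0, 1) and (1, 2).

```python
import numpy as np

# np.put mutates `a` in place and returns None.
a = np.arange(6).reshape(2, 3)
np.put(a, [1, 5], [10, 50])

# The same update written with the flat iterator, on a separate array.
b = np.arange(6).reshape(2, 3)
b.flat[[1, 5]] = [10, 50]
```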