jax.numpy.put

jax.numpy.put(*args, **kwargs)

Replaces specified elements of an array with given values.

LAX-backend implementation of numpy.put().

The NumPy function numpy.put() is not available in JAX and raises a NotImplementedError: np.put modifies its arguments in place, whereas JAX arrays are immutable. For a JAX-compatible approach to array updates, see jax.numpy.ndarray.at.
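A minimal sketch of a functional replacement, assuming that ind holds integer flat indices and that v is a scalar or has the same length as ind. The helper name put_functional is hypothetical and not part of JAX. It flattens the array, applies the update with the .at[...].set() syntax, and reshapes the result. Unlike numpy.put(), it returns a new array and leaves the input unchanged.

```python
import jax.numpy as jnp


def put_functional(a, ind, v):
    """Hypothetical helper: return a copy of `a` with flat positions `ind` set to `v`.

    Assumes `v` is a scalar or matches the length of `ind`; NumPy's cycling
    of a shorter `v` and its `mode` argument are not reproduced here.
    """
    flat = jnp.ravel(a)                           # index into the flattened array, as np.put does
    updated = flat.at[jnp.asarray(ind)].set(v)    # out-of-place update; `flat` is not modified
    return updated.reshape(jnp.shape(a))          # restore the original shape


a = jnp.arange(6).reshape(2, 3)
b = put_functional(a, jnp.array([0, 4]), jnp.array([-10, -20]))
print(b)  # [[-10   1   2]
          #  [  3 -20   5]]
print(a)  # unchanged: [[0 1 2]
          #             [3 4 5]]
```

Because the update is out of place, the result must be reassigned (for example, a = put_functional(a, ind, v)). This is the pattern the .at property is designed for, and it also works inside jax.jit.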

Original docstring below.

The indexing works on the flattened target array. put is roughly equivalent to:

a.flat[ind] = v
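To illustrate the equivalence stated above in plain NumPy, the example below applies np.put and the flat-index assignment to two copies of the same 2-D array. Flat index 4 lands at row 1, column 1, which shows that the indexing is done on the flattened array. Both forms mutate the array in place, and np.put returns None. The sample values are arbitrary.

```python
import numpy as np

ind = [0, 4]       # flat indices into a 2x3 array
v = [-10, -20]     # replacement values, one per index

# np.put mutates its first argument in place and returns None.
a_put = np.arange(6).reshape(2, 3)
ret = np.put(a_put, ind, v)

# The roughly equivalent flat-index assignment, also in place.
a_flat = np.arange(6).reshape(2, 3)
a_flat.flat[ind] = v

print(a_put)   # [[-10   1   2]
               #  [  3 -20   5]]
print(a_flat)  # same result
```

This in-place mutation is exactly what JAX arrays cannot support, which is why jax.numpy.put() raises instead.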