jax.numpy.put

jax.numpy.put(a, ind, v, mode=None, *, inplace=True)

Replaces specified elements of an array with given values.

LAX-backend implementation of numpy.put().

The semantics of numpy.put() are to modify arrays in place, which JAX cannot do because JAX arrays are immutable. jax.numpy.put() therefore adds the inplace parameter, which the user must explicitly set to False to acknowledge this API difference. Instead of modifying a, the function returns an updated copy.
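A minimal sketch of this difference, assuming a standard JAX installation (the array values are arbitrary): passing inplace=False returns a new array and leaves the input untouched, while leaving inplace at its default of True is rejected. The code catches ValueError, the error type current JAX releases raise in this case; treat that detail as an assumption rather than a documented guarantee.

```python
import jax.numpy as jnp

a = jnp.zeros(5, dtype=jnp.int32)

# inplace=False acknowledges that JAX cannot mutate `a`;
# put returns an updated copy instead.
b = jnp.put(a, jnp.array([0, 2]), jnp.array([7, 9]), inplace=False)
print(b.tolist())  # [7, 0, 9, 0, 0]
print(a.tolist())  # [0, 0, 0, 0, 0]  (the input is unchanged)

# Leaving inplace at its default of True raises an error
# (assumed here to be ValueError).
try:
    jnp.put(a, jnp.array([0]), jnp.array([1]))
    raised = False
except ValueError as err:
    raised = True
    print("error:", err)
```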

Original docstring below.

The indexing works on the flattened target array. put is roughly equivalent to:

a.flat[ind] = v
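In JAX, the closest functional analogue of a.flat[ind] = v is to flatten the array, update the copy with the .at[...].set() syntax, and reshape the result back. The sketch below (arbitrary example values) compares that formulation with jnp.put on a 2-D array. The flat indices 1 and 4 address positions (0, 1) and (1, 1) in row-major order.

```python
import jax.numpy as jnp

a = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
ind = jnp.array([1, 4])
v = jnp.array([-1, -2])

# jnp.put indexes `a` as if it were flattened in row-major (C) order.
out_put = jnp.put(a, ind, v, inplace=False)

# Functional version of `a.flat[ind] = v`: flatten, update the copy
# with .at[].set(), then restore the original shape.
out_flat = a.ravel().at[ind].set(v).reshape(a.shape)

print(out_put.tolist())  # [[0, -1, 2], [3, -2, 5]]
assert (out_put == out_flat).all()
```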
Parameters:
  • a (ndarray) – Target array.

  • ind (array_like) – Target indices, interpreted as integers.

  • v (array_like) – Values to place in a at target indices. If v is shorter than ind it will be repeated as necessary.

  • mode ({'raise', 'wrap', 'clip'}, optional) –

    Specifies how out-of-bounds indices will behave.

    • 'raise' – raise an error (default)

    • 'wrap' – wrap around

    • 'clip' – clip to the valid index range

    'clip' mode means that all indices that are too large are replaced by the index that addresses the last element along that axis. Note that this disables indexing with negative numbers. In 'raise' mode, if an exception occurs, the target array may still be modified.

  • inplace (bool, default=True) – If left at its default value of True, JAX raises an error. This is because the semantics of numpy.put() are to modify the array in place, which is not possible in JAX due to the immutability of JAX arrays.
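The following sketch illustrates the repetition of v and the 'clip' and 'wrap' modes on a six-element array. It assumes jax.numpy.put follows the NumPy behaviour described above for these cases, uses only in-bounds indices for the repetition case, and passes mode explicitly for the out-of-bounds case (it does not exercise 'raise').

```python
import jax.numpy as jnp

a = jnp.zeros(6, dtype=jnp.int32)

# v is shorter than ind, so its values are repeated: [1, 2] -> [1, 2, 1, 2].
tiled = jnp.put(a, jnp.array([0, 1, 2, 3]), jnp.array([1, 2]), inplace=False)
print(tiled.tolist())    # [1, 2, 1, 2, 0, 0]

ind = jnp.array([-1, 8])
v = jnp.array([5, 6])

# mode='clip': indices are clipped to [0, a.size - 1], so -1 -> 0 and 8 -> 5.
clipped = jnp.put(a, ind, v, mode="clip", inplace=False)
print(clipped.tolist())  # [5, 0, 0, 0, 0, 6]

# mode='wrap': indices are taken modulo a.size, so -1 -> 5 and 8 -> 2.
wrapped = jnp.put(a, ind, v, mode="wrap", inplace=False)
print(wrapped.tolist())  # [0, 0, 6, 0, 0, 5]
```

Note how 'clip' sends the negative index to the first element rather than counting from the end, which is the "disables indexing with negative numbers" behaviour mentioned above.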

Return type:

Array