jax.numpy.place

jax.numpy.place(arr, mask, vals, *, inplace=True)

Change elements of an array based on conditional and input values.

LAX-backend implementation of numpy.place().

The semantics of numpy.place() are to modify arrays in place, which JAX cannot do because JAX arrays are immutable. Thus jax.numpy.place() adds the inplace parameter, which the user must set to False as a reminder of this API difference.
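A minimal sketch of the JAX calling convention: pass inplace=False and use the returned array. The input array is left unchanged. The array values are illustrative only.

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4, 5])
mask = arr % 2 == 0  # True at the even entries (positions 1 and 3)

# inplace=False is required; jnp.place returns a new array instead of mutating arr
result = jnp.place(arr, mask, jnp.array([10, 20]), inplace=False)
# result -> [1, 10, 3, 20, 5]
# arr    -> [1, 2, 3, 4, 5]  (unchanged)
```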

Original docstring below.

This is similar to np.copyto(arr, vals, where=mask). The difference is that place uses the first N elements of vals, where N is the number of True values in mask, while copyto uses the elements of vals at the positions where mask is True.
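The difference is easiest to see side by side. The sketch below uses NumPy's own np.copyto and np.place for comparison, since they operate in place, and jnp.place with inplace=False. The mask and values are illustrative.

```python
import numpy as np
import jax.numpy as jnp

mask = np.array([False, True, False, True])
vals = np.array([10, 20, 30, 40], dtype=np.int32)

# np.copyto: position i receives vals[i] wherever mask[i] is True
a = np.zeros(4, dtype=np.int32)
np.copyto(a, vals, where=mask)
# a -> [0, 20, 0, 40]

# np.place (in place, NumPy): True positions receive vals[0], vals[1], ... in order
c = np.zeros(4, dtype=np.int32)
np.place(c, mask, vals)
# c -> [0, 10, 0, 20]

# jnp.place: same ordering rule as np.place, but returns a new array
b = jnp.place(jnp.zeros(4, dtype=jnp.int32), mask, vals, inplace=False)
# b -> [0, 10, 0, 20]
```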

Note that extract does the exact opposite of place.
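A small round trip illustrating that relationship: jnp.extract collects the masked elements in row-major order, and jnp.place writes a sequence back into the masked positions in that same order. The array and mask here are arbitrary examples.

```python
import jax.numpy as jnp

arr = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
mask = arr > 2                     # True on the second row only

picked = jnp.extract(mask, arr)    # -> [3, 4, 5]
restored = jnp.place(jnp.zeros_like(arr), mask, picked, inplace=False)
# restored -> [[0, 0, 0], [3, 4, 5]]

# Placing the extracted values back into the original array reproduces it
same = jnp.place(arr, mask, picked, inplace=False)
```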

Parameters:
  • arr (ndarray) – Array to put data into.

  • mask (array_like) – Boolean mask array. Must have the same size as arr.

  • vals (1-D sequence) – Values to put into arr. Only the first N elements are used, where N is the number of True values in mask. If vals has fewer than N elements, it is repeated; if any elements of arr are to be masked, this sequence must be non-empty.

  • inplace (bool, default=True) – If left at its default value of True, JAX will raise an error. This is because the semantics of numpy.place() are to modify the array in place, which is not possible in JAX due to the immutability of JAX arrays.
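The sketch below exercises two of the parameter rules: a vals shorter than N is repeated, and leaving inplace at its default raises. The page only says "an error", so the example catches Exception broadly rather than assuming a specific exception class.

```python
import jax.numpy as jnp

arr = jnp.zeros(5, dtype=jnp.int32)
mask = jnp.array([True, True, False, True, True])  # N = 4 True positions
vals = jnp.array([7, 8], dtype=jnp.int32)          # only 2 values

# vals is repeated to cover all four True positions: 7, 8, 7, 8
cycled = jnp.place(arr, mask, vals, inplace=False)
# cycled -> [7, 8, 0, 7, 8]

# With the default inplace=True, jnp.place raises instead of mutating arr.
# The exact exception type is not stated on this page, so catch broadly.
raised = False
try:
    jnp.place(arr, mask, vals)
except Exception as err:
    raised = True
    print("error:", err)
```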

Return type:

Array