jax.numpy.place
- jax.numpy.place(arr, mask, vals, *, inplace=True)
Change elements of an array based on conditional and input values.
LAX-backend implementation of numpy.place().

The semantics of numpy.place() are to modify arrays in-place, which JAX cannot do because JAX arrays are immutable. Thus jax.numpy.place() adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.

Original docstring below.
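A minimal sketch of what the inplace requirement means in practice, assuming a recent JAX release: calling jnp.place with the default inplace=True fails (this sketch assumes the error is a ValueError, as in the JAX versions we are aware of), while inplace=False returns a new array and leaves the input untouched. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.zeros(5, dtype=jnp.int32)
mask = jnp.array([True, False, True, False, True])
vals = jnp.array([1, 2, 3])

# With the default inplace=True, JAX refuses, since it cannot mutate x.
try:
    jnp.place(x, mask, vals)
    raised = False
except ValueError:  # assumed error type; the docs only say "an error"
    raised = True

# With inplace=False, an updated copy is returned and x is unchanged.
y = jnp.place(x, mask, vals, inplace=False)
print(y)  # [1 0 2 0 3]
print(x)  # [0 0 0 0 0]
```

Because nothing is modified in place, the result has to be bound to a name (here y) to be used.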
Similar to np.copyto(arr, vals, where=mask); the difference is that place uses the first N elements of vals, where N is the number of True values in mask, while copyto uses the elements where mask is True.

Note that extract does the exact opposite of place.
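The contrast between place, copyto, and extract can be sketched as follows. Because np.copyto writes into its argument, this sketch assumes jnp.where(mask, vals, arr) as a functional stand-in for np.copyto(arr, vals, where=mask). The last line also exercises the repetition rule stated in the vals parameter description.

```python
import jax.numpy as jnp

arr = jnp.arange(6)                        # [0 1 2 3 4 5]
mask = arr % 2 == 1                        # True at positions 1, 3, 5 (N = 3)
vals = jnp.array([10, 20, 30, 40, 50, 60])

# place: the N masked slots receive the FIRST N values of vals, in order.
placed = jnp.place(arr, mask, vals, inplace=False)
# -> [0, 10, 2, 20, 4, 30]

# copyto-style (via where): each masked slot takes vals at the SAME position.
copied = jnp.where(mask, vals, arr)
# -> [0, 20, 2, 40, 4, 60]

# extract reads the masked slots back out, recovering what place wrote.
recovered = jnp.extract(mask, placed)
# -> [10, 20, 30]

# If vals has fewer than N elements, it is repeated to fill the masked slots.
short = jnp.place(arr, mask, jnp.array([7, 8]), inplace=False)
# -> [0, 7, 2, 8, 4, 7]
```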
- Parameters:
arr (ndarray) – Array to put data into.
mask (array_like) – Boolean mask array. Must have the same size as arr.
vals (1-D sequence) – Values to put into arr. Only the first N elements are used, where N is the number of True values in mask. If vals is smaller than N, it will be repeated, and if elements of arr are to be masked, this sequence must be non-empty.
inplace (bool, default=True) – If left to its default value of True, JAX will raise an error. This is because the semantics of numpy.place() are to modify the array in-place, which is not possible in JAX due to the immutability of JAX arrays.
- Return type: