jax.numpy.place
jax.numpy.place(arr, mask, vals, *, inplace=True)
Update array elements based on a mask.
JAX implementation of numpy.place().

The semantics of numpy.place() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version instead returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.

Parameters:
- arr (ArrayLike) – array into which values will be placed.
- mask (ArrayLike) – boolean mask with the same size as arr.
- vals (ArrayLike) – values to be inserted into arr at the locations indicated by mask, in row-major order. If too many values are supplied, the extra values are ignored. If not enough values are supplied, they are repeated cyclically.
- inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
Returns:
A copy of arr with the masked entries set to values from vals.

Return type:
Array
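As a quick check of the inplace requirement and copy semantics described above, the sketch below calls jnp.place both ways on a small one-dimensional array. It assumes a recent JAX release, in which leaving inplace at its default of True is rejected with a ValueError; the array and mask are illustrative only.

```python
import jax.numpy as jnp

x = jnp.zeros(5, dtype=int)
mask = jnp.array([True, False, True, False, True])

# Leaving inplace at its default (True) is rejected, since JAX arrays cannot be mutated.
# Assumption: current JAX releases raise ValueError here.
rejected = False
try:
    jnp.place(x, mask, 7)
except ValueError as err:
    rejected = True
    print("rejected:", err)

# With inplace=False, a new array is returned and x itself is left unchanged.
y = jnp.place(x, mask, 7, inplace=False)
print(y)  # [7 0 7 0 7]
print(x)  # [0 0 0 0 0]
```

Because the result is a fresh array, code ported from NumPy has to rebind the name (for example x = jnp.place(x, mask, 7, inplace=False)) rather than rely on mutation.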
See also
- jax.numpy.put(): put elements into an array at numerical indices.
- jax.numpy.ndarray.at(): array updates using NumPy-style indexing.
Examples
>>> import jax.numpy as jnp
>>> x = jnp.zeros((3, 5), dtype=int)
>>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
>>> mask
Array([[ True, False, False,  True, False],
       [False,  True, False, False,  True],
       [False, False,  True, False, False]], dtype=bool)
Placing a scalar value:
>>> jnp.place(x, mask, 1, inplace=False)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)
In this case, jnp.place is similar to the masked array update syntax:

>>> x.at[mask].set(1)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)
jnp.place differs when placing values from an array: the values are assigned to the masked entries in row-major order, and the array is repeated as needed to fill them:

>>> vals = jnp.array([1, 3, 5])
>>> jnp.place(x, mask, vals, inplace=False)
Array([[1, 0, 0, 3, 0],
       [0, 5, 0, 0, 1],
       [0, 0, 3, 0, 0]], dtype=int32)
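To make the repeat-and-truncate rule concrete, here is a minimal sketch of the same semantics built from jnp.nonzero, jnp.resize, and .at[].set. The helper name place_sketch is hypothetical, and this is not JAX's actual implementation. It also calls jnp.nonzero without a size argument, so as written it assumes eager execution rather than jax.jit.

```python
import jax.numpy as jnp

def place_sketch(arr, mask, vals):
    """Hypothetical helper mirroring jnp.place(arr, mask, vals, inplace=False).

    Illustration only; assumes mask has the same size as arr.
    """
    arr = jnp.asarray(arr)
    flat_mask = jnp.ravel(jnp.asarray(mask))
    # Flat (row-major) positions of the True entries in the mask.
    (idx,) = jnp.nonzero(flat_mask)
    # Cycle or truncate vals so there is exactly one value per masked slot.
    fill = jnp.resize(jnp.ravel(jnp.asarray(vals)), idx.shape[0])
    # Scatter into a flattened copy, then restore the original shape.
    return arr.ravel().at[idx].set(fill.astype(arr.dtype)).reshape(arr.shape)

x = jnp.zeros((3, 5), dtype=int)
mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
vals = jnp.array([1, 3, 5])

# Matches jnp.place on the repeated-values example above.
print(bool(jnp.array_equal(place_sketch(x, mask, vals),
                           jnp.place(x, mask, vals, inplace=False))))  # True

# Surplus values are dropped: only the first five of these ten are used.
print(place_sketch(x, mask, jnp.arange(10, 20)).ravel()[::3])  # [10 11 12 13 14]
```

The flatten, scatter, and reshape pattern also shows why mask only needs the same size as arr rather than the same shape: both are consumed in flattened order.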