jax.numpy.place

jax.numpy.place(arr, mask, vals, *, inplace=True)

Update array elements based on a mask.

JAX implementation of numpy.place().

The semantics of numpy.place() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version instead returns a modified copy of the input, and adds an inplace parameter that the user must explicitly set to False as a reminder of this API difference.
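A short sketch of this difference, assuming a recent JAX release in which leaving inplace at its default of True raises a ValueError (the array values are illustrative):

```python
import jax.numpy as jnp

x = jnp.zeros(4, dtype=jnp.int32)
mask = jnp.array([True, False, True, False])

# jnp.place returns a new array; x itself is left unchanged.
y = jnp.place(x, mask, 7, inplace=False)

# Keeping the default inplace=True is rejected, because JAX arrays
# cannot be modified in place (assumed to raise ValueError).
try:
    jnp.place(x, mask, 7)
    rejected = False
except ValueError:
    rejected = True

print(x, y, rejected)
```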

Parameters:
  • arr (ArrayLike) – array into which values will be placed.

  • mask (ArrayLike) – boolean mask with the same size as arr.

  • vals (ArrayLike) – values to be inserted into arr at the locations indicated by mask. Only as many values as there are True entries in mask are used: if too many values are supplied, the extras are ignored; if too few are supplied, they are repeated cyclically.

  • inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
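The truncation and repetition rules for vals can be checked with a small 1-D example (the array contents here are illustrative):

```python
import jax.numpy as jnp

arr = jnp.zeros(6, dtype=jnp.int32)
mask = jnp.array([True, False, True, True, False, True])  # 4 True entries

# More values than True entries: only the first 4 values (1, 2, 3, 4) are used.
too_many = jnp.place(arr, mask, jnp.array([1, 2, 3, 4, 5, 6]), inplace=False)

# Fewer values than True entries: the values are cycled as 1, 2, 1, 2.
too_few = jnp.place(arr, mask, jnp.array([1, 2]), inplace=False)

print(too_many)
print(too_few)
```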

Returns:

A copy of arr with masked values set to entries from vals.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> x = jnp.zeros((3, 5), dtype=int)
>>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
>>> mask
Array([[ True, False, False,  True, False],
       [False,  True, False, False,  True],
       [False, False,  True, False, False]], dtype=bool)

Placing a scalar value:

>>> jnp.place(x, mask, 1, inplace=False)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

For a scalar value, jnp.place gives the same result as the masked array update syntax:

>>> x.at[mask].set(1)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)
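The scalar-case equivalence can be confirmed directly with jnp.array_equal, reusing the x and mask from the example above:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 5), dtype=int)
mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)

placed = jnp.place(x, mask, 1, inplace=False)  # copy-returning place
updated = x.at[mask].set(1)                    # masked .at[] update

same = bool(jnp.array_equal(placed, updated))
print(same)
```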

The two differ when vals is an array: jnp.place assigns its entries, in row-major order, to the True positions of mask, repeating them as needed to fill every masked entry:

>>> vals = jnp.array([1, 3, 5])
>>> jnp.place(x, mask, vals, inplace=False)
Array([[1, 0, 0, 3, 0],
       [0, 5, 0, 0, 1],
       [0, 0, 3, 0, 0]], dtype=int32)
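To make the repetition rule concrete, here is a hypothetical helper, place_sketch, that reproduces this result using only .at[...].set. It illustrates the semantics under the assumption that vals is non-empty; it is not JAX's actual implementation:

```python
import jax.numpy as jnp

def place_sketch(arr, mask, vals):
    """Hypothetical illustration of jnp.place(..., inplace=False) semantics."""
    arr = jnp.asarray(arr)
    flat_mask = jnp.ravel(jnp.asarray(mask))
    vals = jnp.ravel(jnp.asarray(vals))  # assumed non-empty
    # Flat (row-major) positions of the True entries in mask.
    idx = jnp.nonzero(flat_mask)[0]
    # One value per True entry: cycle through vals, ignoring any extras.
    tiled = vals[jnp.arange(idx.size) % vals.size]
    return arr.ravel().at[idx].set(tiled.astype(arr.dtype)).reshape(arr.shape)

x = jnp.zeros((3, 5), dtype=int)
mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
vals = jnp.array([1, 3, 5])

result = place_sketch(x, mask, vals)
print(result)
```

Because jnp.nonzero is called here without a size argument, this sketch needs concrete (non-traced) inputs and would not work inside jax.jit.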