jax.numpy.place

jax.numpy.place(*args, **kwargs)

Change elements of an array based on conditional and input values.

LAX-backend implementation of numpy.place().

NumPy's numpy.place() is not available in JAX: calling jax.numpy.place raises a NotImplementedError, because np.place modifies its arguments in place and JAX arrays are immutable. For a JAX-compatible way to update arrays, see jax.numpy.ndarray.at.
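Because np.place mutates its first argument, a JAX substitute has to return a new array instead. The sketch below is one possible functional stand-in, assuming an updated copy is acceptable; `place_functional` is a hypothetical helper, not part of `jax.numpy`. Rather than boolean `.at[mask]` indexing (which needs a concrete mask), it ranks the True entries with a running count and selects with `jnp.where`, so every shape is static and the helper can also run under `jax.jit`.

```python
import jax
import jax.numpy as jnp


def place_functional(arr, mask, vals):
    """Hypothetical helper: return a copy of `arr` whose True positions in
    `mask` are filled, in row-major order, from `vals`. Like np.place, a
    `vals` shorter than the number of True entries is reused cyclically."""
    arr = jnp.asarray(arr)
    flat_arr = arr.ravel()
    flat_mask = jnp.asarray(mask, dtype=bool).ravel()
    flat_vals = jnp.asarray(vals, dtype=arr.dtype).ravel()  # assumed non-empty

    # Rank of each True position among the True positions: 0, 1, 2, ...
    rank = jnp.cumsum(flat_mask.astype(jnp.int32)) - 1
    # Wrap the rank so a short `vals` is repeated, as np.place does.
    src = rank % flat_vals.size
    # Take from vals where mask is True, keep the original value elsewhere.
    out = jnp.where(flat_mask, flat_vals[src], flat_arr)
    return out.reshape(arr.shape)


arr = jnp.arange(6).reshape(2, 3)
result = place_functional(arr, arr > 2, jnp.array([44, 55]))
# result → [[0, 1, 2], [44, 55, 44]]; `arr` itself is left unchanged.

# No data-dependent shapes are involved, so the helper can be jit-compiled.
result_jit = jax.jit(place_functional)(arr, arr > 2, jnp.array([44, 55]))
```

The cumulative-sum trick trades a little extra work for jit compatibility; when the mask is always concrete, an `.at[...].set(...)` update on the True indices is a simpler alternative.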

Original docstring below.

This is similar to np.copyto(arr, vals, where=mask). The difference is that place uses the first N elements of vals, where N is the number of True values in mask, whereas copyto uses the elements of vals at the positions where mask is True.
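To make the contrast concrete, the sketch below emulates both behaviors on a 1-D array with ordinary `jax.numpy` calls; neither line is `jax.numpy.place` or `np.copyto` itself. It assumes a concrete (non-traced) mask, because `jnp.nonzero` without a `size` argument has a data-dependent output shape and so does not run under `jax.jit`.

```python
import jax.numpy as jnp

arr = jnp.array([10, 20, 30, 40, 50])
mask = jnp.array([False, True, False, True, False])
vals = jnp.array([1, 2, 3, 4, 5])

# copyto-style: use the entries of vals at the same positions where mask is True.
copyto_like = jnp.where(mask, vals, arr)  # → [10, 2, 30, 4, 50]

# place-style: fill the True positions, in order, with the first N entries of vals.
(true_idx,) = jnp.nonzero(mask)  # N = 2 True positions: indices 1 and 3
place_like = arr.at[true_idx].set(vals[: true_idx.size])  # → [10, 1, 30, 2, 50]
```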

Note that extract does the exact opposite of place: it reads out the elements where mask is True instead of writing to them.
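The inverse relationship can be checked with `jnp.extract`, which does exist in `jax.numpy`. In this sketch the place-style write is emulated with `.at` (again assuming a concrete mask and exactly N values); reading the same positions back with `jnp.extract` recovers `vals`.

```python
import jax.numpy as jnp

arr = jnp.arange(6)
mask = arr % 2 == 0          # True at positions 0, 2, 4
vals = jnp.array([7, 8, 9])  # exactly N = 3 values

# place-style write, emulated functionally on the True indices.
(idx,) = jnp.nonzero(mask)
placed = arr.at[idx].set(vals)  # → [7, 1, 8, 3, 9, 5]

# extract reads the same positions back out, undoing the write.
recovered = jnp.extract(mask, placed)  # → [7, 8, 9]
```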