jax.numpy.where

jax.numpy.where(condition: Array | ndarray | bool_ | number | bool | int | float | complex, x: Literal[None] = None, y: Literal[None] = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> tuple[Array, ...]
jax.numpy.where(condition: Array | ndarray | bool_ | number | bool | int | float | complex, x: Array | ndarray | bool_ | number | bool | int | float | complex, y: Array | ndarray | bool_ | number | bool | int | float | complex, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array
jax.numpy.where(condition: Array | ndarray | bool_ | number | bool | int | float | complex, x: ArrayLike | None = None, y: ArrayLike | None = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array | tuple[Array, ...]

Select elements from two arrays based on a condition.

JAX implementation of numpy.where().

Note

When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero(). The docstring below focuses on the case where x and y are specified.

The three-term version of jnp.where lowers to jax.lax.select().
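As a rough illustration of that lowering, the sketch below compares the three-argument jnp.where with a direct call to jax.lax.select(). It assumes that lax.select, unlike jnp.where, does not broadcast or promote its operands, so the scalar fallback value is expanded to the shape and dtype of x by hand; the array names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

cond = jnp.array([True, False, True])
x = jnp.array([1.0, 2.0, 3.0])

# jnp.where broadcasts the scalar 0.0 against x automatically.
via_where = jnp.where(cond, x, 0.0)

# lax.select expects on_true and on_false with identical shape and dtype,
# so the fallback value is materialized explicitly with full_like.
via_select = lax.select(cond, x, jnp.full_like(x, 0.0))

print(via_where.tolist())   # [1.0, 0.0, 3.0]
print(via_select.tolist())  # [1.0, 0.0, 3.0]
```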

Parameters:
  • condition – boolean array. Must be broadcast-compatible with x and y when they are specified.

  • x – arraylike. Should be broadcast-compatible with condition and y, and typecast-compatible with y.

  • y – arraylike. Should be broadcast-compatible with condition and x, and typecast-compatible with x.

  • size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero().

  • fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero().
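The size and fill_value parameters matter mainly under jax.jit(), where the output shape of the one-argument form must be known at trace time. The following is a minimal sketch of that usage; the helper function names are hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)

@jax.jit
def large_indices(v):
    # `size` fixes the length of the returned index array; extra matches
    # beyond the first `size` are truncated.
    return jnp.where(v > 4, size=3)

@jax.jit
def large_indices_padded(v):
    # If fewer than `size` elements satisfy the condition, the result is
    # padded with `fill_value` (here -1 instead of the default).
    return jnp.where(v > 4, size=7, fill_value=-1)

(idx,) = large_indices(x)
(padded,) = large_indices_padded(x)
print(idx.tolist())     # [5, 6, 7]
print(padded.tolist())  # [5, 6, 7, 8, 9, -1, -1]
```

Without size, the one-argument call inside a jitted function cannot determine its output shape and raises an error at trace time.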

Returns:

An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero() for a description of the return type.
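To make the shape and dtype rules concrete, here is a small sketch in which condition, x and y have different shapes and dtypes. It illustrates broadcasting to a common shape and promotion to jnp.result_type(x, y); the exact floating dtype (float32 or float64) depends on whether 64-bit mode is enabled.

```python
import jax.numpy as jnp

cond = jnp.array([[True], [False]])  # shape (2, 1)
x = jnp.arange(3)                    # shape (3,), integer dtype
y = 0.5                              # Python scalar

out = jnp.where(cond, x, y)

# The operands broadcast to the common shape (2, 3).
print(out.shape)  # (2, 3)
# The output dtype follows jnp.result_type(x, y): an integer array
# combined with a Python float promotes to the default floating dtype.
print(out.dtype == jnp.result_type(x, y))  # True
# Row 0 takes values from x, row 1 takes the fallback 0.5.
print(out.tolist())  # [[0.0, 1.0, 2.0], [0.5, 0.5, 0.5]]
```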

Notes

Special care is needed when the computation producing the x or y input to jax.numpy.where() could yield NaN, even at positions that condition does not select. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), both branches are differentiated, and a NaN arising in either x or y will propagate into the gradient regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
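The sketch below illustrates the problem with a guarded logarithm, together with a "double where" workaround along the lines of the one described in the JAX FAQ. The function names are hypothetical. In the naive version, jnp.log(x) is still evaluated and differentiated at x = 0, and the zero cotangent from the unselected branch combines with an undefined derivative to give NaN.

```python
import jax
import jax.numpy as jnp

def log_or_zero_naive(x):
    # log(x) is computed even where x <= 0; at x = 0 its backward pass
    # produces NaN, which leaks through the where into the gradient.
    return jnp.where(x > 0, jnp.log(x), 0.0)

def log_or_zero_safe(x):
    # "Double where": first replace problematic inputs with a harmless
    # value (1.0) so that log and its derivative stay finite everywhere.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

g_naive = jax.grad(log_or_zero_naive)(0.0)
g_safe_zero = jax.grad(log_or_zero_safe)(0.0)
g_safe_two = jax.grad(log_or_zero_safe)(2.0)
print(g_naive)      # nan
print(g_safe_zero)  # 0.0
print(g_safe_two)   # 0.5, i.e. d/dx log(x) at x = 2
```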

Examples

When x and y are not provided, where behaves equivalently to jax.numpy.nonzero():

>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
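For a multi-dimensional condition, the one-argument form returns one index array per axis, following the jax.numpy.nonzero() convention. The following sketch, using an illustrative 2-D array, shows that the resulting tuple can be used directly to gather the selected elements.

```python
import jax.numpy as jnp

m = jnp.array([[0, 3],
               [4, 0]])

# One index array per axis: (row_indices, col_indices), in row-major order.
rows, cols = jnp.where(m > 0)
print(rows.tolist(), cols.tolist())  # [0, 1] [1, 0]

# The index tuple gathers the elements that satisfy the condition.
print(m[rows, cols].tolist())  # [3, 4]
```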

When x and y are provided, where selects between them based on the specified condition:

>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
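Both x and y can also be full arrays, in which case each output element is taken from the corresponding position of one or the other. In the small sketch below (array names illustrative), the condition a > b happens to reproduce an elementwise maximum, which makes the result easy to check.

```python
import jax.numpy as jnp

a = jnp.array([1, 5, 3])
b = jnp.array([4, 2, 6])

# Take the element from `a` where it is larger, otherwise from `b`.
larger = jnp.where(a > b, a, b)
print(larger.tolist())  # [4, 5, 6]

# For this particular condition the result matches jnp.maximum.
print(bool((larger == jnp.maximum(a, b)).all()))  # True
```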