jax.numpy.where(condition, x=None, y=None, *, size=None, fill_value=None)

Return elements chosen from x or y depending on condition.

LAX-backend implementation of numpy.where().

At present, JAX does not support JIT compilation of the single-argument form of jax.numpy.where() because its output shape is data-dependent. The three-argument form does not have a data-dependent shape and can be JIT-compiled successfully. Alternatively, the single-argument form can be JIT-compiled if you pass the optional size keyword to statically specify the expected size of the output.
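A minimal sketch of both forms under jax.jit, assuming a recent JAX release. The function names select_positive and positive_indices are illustrative only. The three-argument call has an output shape fixed by broadcasting. The single-argument call needs a static size; here size=3 and fill_value=-1 are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def select_positive(x):
    # Three-argument form: the output shape is the broadcast shape of the
    # inputs, which is known at trace time, so jit works as usual.
    return jnp.where(x > 0, x, 0.0)


@jax.jit
def positive_indices(x):
    # Single-argument form: the number of matches depends on the data, so a
    # static `size` is required under jit. Unused slots are set to fill_value.
    return jnp.where(x > 0, size=3, fill_value=-1)


x = jnp.array([1.0, -2.0, 3.0, -4.0])
print(select_positive(x))       # values: 1, 0, 3, 0
(idx,) = positive_indices(x)    # a 1-D input gives a 1-tuple of index arrays
print(idx)                      # values: 0, 2, -1
```

Without size=..., the second function would fail at trace time, because the number of positive entries cannot be known until the data is available.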

Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
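A minimal sketch of the workaround commonly suggested for this problem, often called the "double where" trick. It uses jnp.log as an illustrative unsafe branch because its derivative is infinite at 0. The helper names naive_log and safe_log are hypothetical. The idea is to sanitize the input before calling the risky function, so that the branch which gets discarded never produces an infinite or NaN derivative.

```python
import jax
import jax.numpy as jnp


def naive_log(x):
    # jnp.log(x) is computed even where the condition is False. In the backward
    # pass, where() sends a zero cotangent to that branch, but log's derivative
    # rule divides the cotangent by x, and at x == 0 that is 0 / 0 = NaN.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def safe_log(x):
    # "Double where": first replace unsafe inputs with a harmless value (1.0),
    # so the log branch is never differentiated at x == 0.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


print(jax.grad(naive_log)(0.0))  # nan
print(jax.grad(safe_log)(0.0))   # 0.0
print(jax.grad(safe_log)(2.0))   # 0.5
```

The forward values of both functions agree. Only the gradients differ, which is why this issue is easy to miss.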

Original docstring below.


When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided. The exceptions are size and fill_value, which apply only to the single-argument form.
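For illustration, the single-argument form returns one index array per dimension of condition, just as jnp.nonzero does. Here is a minimal sketch, run outside of jit, with an arbitrary 2x2 mask:

```python
import jax.numpy as jnp

# An arbitrary 2x2 boolean mask for illustration.
cond = jnp.array([[True, False],
                  [False, True]])

# Single-argument form: one index array per dimension of `cond`.
rows, cols = jnp.where(cond)
print(rows, cols)  # rows: 0, 1   cols: 0, 1

# Same result as calling jnp.nonzero directly.
rows_nz, cols_nz = jnp.nonzero(cond)
assert (rows == rows_nz).all() and (cols == cols_nz).all()
```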

Parameters

  • condition (array_like, bool) – Where True, yield x, otherwise yield y.

  • x (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.

  • y (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.

  • size (int, optional) – Only referenced when x and y are None. If specified, the indices of the first size nonzero elements of condition are returned. If condition has fewer nonzero elements than size, the result is padded with fill_value.

  • fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
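A short sketch of the parameters above, using arbitrary example arrays. In the three-argument form, condition, x and y are broadcast together. The size and fill_value arguments affect only the single-argument form, where they truncate or pad the returned indices.

```python
import jax.numpy as jnp

# Broadcasting: condition (3, 1), x (1, 4) and a scalar y give a (3, 4) result.
cond = jnp.array([[True], [False], [True]])
x = jnp.arange(4.0).reshape(1, 4)
out = jnp.where(cond, x, -1.0)
print(out.shape)  # (3, 4); the middle row is all -1.0

# size / fill_value: only used in the single-argument form.
mask = jnp.array([0, 3, 0, 5, 7])                       # nonzero at indices 1, 3, 4
(first_two,) = jnp.where(mask, size=2)                  # truncated: 1, 3
(padded,) = jnp.where(mask, size=5)                     # default fill 0: 1, 3, 4, 0, 0
(padded_neg,) = jnp.where(mask, size=5, fill_value=-1)  # 1, 3, 4, -1, -1
print(first_two, padded, padded_neg)
```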


Returns

out – An array with elements from x where condition is True, and elements from y elsewhere.

Return type