jax.numpy.where
- jax.numpy.where(condition, x=None, y=None, *, size=None, fill_value=None)
Return elements chosen from x or y depending on condition.
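A minimal sketch of the three-argument form, assuming a standard JAX install; the arrays are illustrative values, not taken from this page. It shows the elementwise selection and how condition, x and y are broadcast to a common shape.

```python
import jax.numpy as jnp

# Illustrative inputs (not from the original page).
condition = jnp.array([True, False, True, False])
x = jnp.array([1, 2, 3, 4])
y = jnp.array([10, 20, 30, 40])

# Elementwise select: x where condition is True, y elsewhere.
out = jnp.where(condition, x, y)  # values [1, 20, 3, 40]

# condition, x and y only need to broadcast to a common shape:
# a (2, 1) condition, a (3,) row and a scalar broadcast to (2, 3).
cond_col = jnp.array([[True], [False]])
row = jnp.array([1.0, 2.0, 3.0])
out2d = jnp.where(cond_col, row, 0.0)  # [[1., 2., 3.], [0., 0., 0.]]
```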
LAX-backend implementation of numpy.where().

At present, JAX does not support JIT-compilation of the single-argument form of jax.numpy.where() because its output shape is data-dependent. The three-argument form does not have a data-dependent shape and can be JIT-compiled successfully. Alternatively, the single-argument form can be JIT-compiled if you pass the optional size keyword to statically specify the expected size of the output.

Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

Original docstring below.
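The sketch below reproduces the NaN-gradient issue with a logarithm and applies the "double where" pattern, a commonly used workaround in the spirit of the linked FAQ. The function names naive_log and safe_log are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def naive_log(x):
    # Hypothetical helper: log(x) is evaluated (and differentiated) for every
    # input, even where the mask discards it, so log's derivative at 0 is
    # multiplied by a zero cotangent and the gradient becomes NaN.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

def safe_log(x):
    # Hypothetical helper using a "double where": first replace unsafe inputs
    # with a harmless value so log and its derivative stay finite, then select.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)

g_naive = jax.grad(naive_log)(0.0)  # NaN, although the forward value is 0.0
g_safe = jax.grad(safe_log)(0.0)    # 0.0
g_pos = jax.grad(safe_log)(2.0)     # 1 / 2 = 0.5, same as d/dx log(x)
```

The inner where only changes inputs on the branch that the outer where discards, so forward values are unchanged while the discarded branch no longer produces an infinite derivative.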
Note
When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided.
- Parameters
condition (array_like, bool) – Where True, yield x, otherwise yield y.
x (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
y (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
size (int, optional) – Only referenced when x and y are None. If specified, the indices of the first size elements of the result will be returned. If there are fewer elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
- Returns
out – An array with elements from x where condition is True, and elements from y elsewhere.
- Return type
ndarray
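As a sketch of the size and fill_value parameters, and of JIT-compiling the single-argument form, the hypothetical helper first_nonzero_indices below fixes size=3 so that the output shape is static; the input values are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_nonzero_indices(a):
    # Hypothetical helper. The single-argument form has a data-dependent
    # output shape, so under jit a static `size` is required; unused slots
    # are padded with `fill_value`.
    return jnp.where(a, size=3, fill_value=-1)

a = jnp.array([0, 5, 0, 7])
(idx,) = first_nonzero_indices(a)  # two nonzeros, padded: [1, 3, -1]

# With more than `size` nonzero entries, only the first `size` indices are kept.
(idx_trunc,) = first_nonzero_indices(jnp.array([1, 1, 1, 1, 1]))  # [0, 1, 2]

# The single-argument form matches nonzero with the same size/fill_value.
(idx_nz,) = jnp.nonzero(a, size=3, fill_value=-1)
```

Like nonzero, the single-argument form returns a tuple with one index array per dimension of the input, which is why a 1-D input is unpacked into a single array here.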