jax.numpy.where(acondition=None, if_true=None, if_false=None, /, *, size=None, fill_value=None, condition=<object object>, x=<object object>, y=<object object>)
Return elements chosen from x or y depending on condition.
LAX-backend implementation of numpy.where().
At present, JAX does not support JIT-compilation of the single-argument form of jax.numpy.where() because its output shape is data-dependent. The three-argument form does not have a data-dependent shape and can be JIT-compiled successfully. Alternatively, the single-argument form can be JIT-compiled if you use the optional size keyword to statically specify the expected size of the output.
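The sketch below illustrates both JIT-compatible forms. The three-argument call has an output shape fixed by broadcasting. The single-argument call becomes compilable once size pins the number of returned indices. The function names relu and first_positive_indices are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu(x):
    # Three-argument form: the output shape is the broadcast shape of the
    # inputs, which is known at trace time, so this compiles under jit.
    return jnp.where(x > 0, x, 0.0)

@jax.jit
def first_positive_indices(x):
    # Single-argument form: without `size`, the number of returned indices
    # would depend on the data. `size=3` makes the output shape static;
    # unused slots are padded with the default fill value (zero).
    (idx,) = jnp.where(x > 0, size=3)
    return idx

x = jnp.array([-1.0, 2.0, -3.0, 4.0])
print(relu(x).tolist())                    # [0.0, 2.0, 0.0, 4.0]
print(first_positive_indices(x).tolist())  # [1, 3, 0]
```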
Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
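The following sketch reproduces the problem and applies a common workaround, often called the "double where" trick. It first replaces unsafe inputs with a harmless value, so that the branch that is not selected never sees them. The replacement value 1.0 and the function names naive_log and safe_log are illustrative assumptions. For the authoritative discussion, see the FAQ linked above.

```python
import jax
import jax.numpy as jnp

def naive_log(x):
    # Both branches are evaluated. At x == 0, jnp.log(x) is -inf and its
    # derivative 1/x is infinite. In the backward pass, the zero cotangent of
    # the unselected branch meets that infinity and yields NaN.
    return jnp.where(x > 0, jnp.log(x), 0.0)

def safe_log(x):
    # First swap unsafe inputs for a harmless value (1.0), so jnp.log never
    # receives 0 and its gradient stays finite in the masked-out branch.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

print(jax.grad(naive_log)(0.0))  # nan
print(jax.grad(safe_log)(0.0))   # 0.0
print(jax.grad(safe_log)(2.0))   # 0.5
```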
Original docstring below.
When only condition is provided, this function is a shorthand for
np.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided.
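As a quick illustration of the single-argument form, the sketch below checks that jnp.where(cond) returns one index array per dimension, in row-major order, matching jnp.nonzero. The example array is arbitrary.

```python
import jax.numpy as jnp

cond = jnp.array([[False, True],
                  [True, True]])

# Single-argument form: a tuple of index arrays, one per dimension.
rows, cols = jnp.where(cond)
# The equivalent call that the docstring recommends.
nz_rows, nz_cols = jnp.nonzero(cond)

print(rows.tolist(), cols.tolist())  # [0, 1, 1] [1, 0, 1]
```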
Parameters:
condition (array_like, bool) – Where True, yield x, otherwise yield y.
x (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
y (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
size (int, optional) – Only referenced when x and y are None. If specified, the indices of the first size elements of the result will be returned. If there are fewer elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
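The sketch below shows how size and fill_value interact in the single-argument form. A size larger than the number of True entries pads the result. A smaller size keeps only the first indices. Omitting fill_value pads with zero. The input array is an arbitrary example.

```python
import jax.numpy as jnp

cond = jnp.array([True, False, True, False, False])  # True at indices 0 and 2

# Fewer True entries (2) than `size` (4): pad with `fill_value`.
(padded,) = jnp.where(cond, size=4, fill_value=-1)
# More True entries than `size` (1): keep only the first index.
(truncated,) = jnp.where(cond, size=1)
# No `fill_value`: padding defaults to zero.
(default_fill,) = jnp.where(cond, size=3)

print(padded.tolist())        # [0, 2, -1, -1]
print(truncated.tolist())     # [0]
print(default_fill.tolist())  # [0, 2, 0]
```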
Returns:
out – An array with elements from x where condition is True, and elements from y elsewhere.
Return type:
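To make the broadcasting rule for the three-argument form concrete, the sketch below mixes a column-shaped condition, a scalar x, and a 1-D y. The result takes the broadcast shape (3, 4). Rows where condition is True are filled from x, and the False row copies y. The specific shapes and values are illustrative only.

```python
import jax.numpy as jnp

cond = jnp.array([[True], [False], [True]])  # shape (3, 1)
x = 1.0                                      # scalar, broadcast everywhere
y = jnp.array([10.0, 20.0, 30.0, 40.0])      # shape (4,)

out = jnp.where(cond, x, y)
print(out.shape)        # (3, 4)
print(out[1].tolist())  # [10.0, 20.0, 30.0, 40.0]
```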