jax.numpy.where
jax.numpy.where(condition, x=None, y=None)
Return elements chosen from x or y depending on condition.
LAX-backend implementation of numpy.where(). At present, JAX does not support JIT compilation of the single-argument form of jax.numpy.where() because its output shape is data-dependent. The three-argument form does not have a data-dependent shape and can be JIT-compiled successfully.
Original docstring below.
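The sketch below, assuming a recent JAX installation, contrasts the two forms under jax.jit. The three-argument call compiles because its output shape is the broadcast shape of its inputs. The single-argument call returns a tuple of index arrays whose length depends on how many entries of condition are True, so tracing it under jit fails. The helper names `select` and `indices` are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def select(cond, x, y):
    # Three-argument form: the output shape equals the broadcast shape of the
    # inputs, which is known at trace time, so this compiles.
    return jnp.where(cond, x, y)


@jax.jit
def indices(cond):
    # Single-argument form (hypothetical helper): the number of returned indices
    # depends on the *values* in cond, which are unknown while tracing.
    return jnp.where(cond)


cond = jnp.array([True, False, True, False])
x = jnp.arange(4)
y = -jnp.arange(4)

selected = select(cond, x, y)  # values: [0, -1, 2, -3]

# Outside of jit, the single-argument form runs eagerly and matches jnp.nonzero.
eager = jnp.where(cond)        # a 1-tuple holding the indices [0, 2]
nz = jnp.nonzero(cond)
same_as_nonzero = all(bool(jnp.array_equal(a, b)) for a, b in zip(eager, nz))

# Under jit, the data-dependent output shape makes tracing fail; the error is
# expected to be a concretization error, but any exception is caught here.
try:
    indices(cond)
    jit_single_arg_ok = True
except Exception as err:
    jit_single_arg_ok = False
    print(type(err).__name__)
```

Catching a broad `Exception` keeps the sketch independent of the exact error class, which may differ between JAX versions.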
Note
When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided.
- Parameters
condition (array_like, bool) – Where True, yield x, otherwise yield y.
x (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
y (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
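As a sketch of the broadcasting requirement on x, y and condition, the example below (arrays chosen purely for illustration) combines a column-shaped condition of shape (3, 1), a scalar x, and a row-shaped y of shape (4,). All three broadcast to (3, 4), which becomes the shape of the result.

```python
import jax.numpy as jnp

# condition has shape (3, 1): one boolean per row.
condition = jnp.array([[True], [False], [True]])
# x is a scalar and y has shape (4,); together with condition they
# broadcast to a common shape of (3, 4).
x = 0.0
y = jnp.arange(4.0)

out = jnp.where(condition, x, y)

# Rows where condition is True are filled with x; the False row is copied from y.
# out.shape == (3, 4)
# out == [[0, 0, 0, 0],
#         [0, 1, 2, 3],
#         [0, 0, 0, 0]]
print(out.shape)
print(out)
```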
- Returns
out – An array with elements from x where condition is True, and elements from y elsewhere.
- Return type