where(condition, x=None, y=None, *, size=None, fill_value=None)
Return elements chosen from x or y depending on condition.
LAX-backend implementation of numpy.where().
At present, JAX does not support JIT-compilation of the single-argument form of jax.numpy.where() because its output shape is data-dependent. The three-argument form does not have a data-dependent shape and can be JIT-compiled successfully. Alternatively, you can specify the optional size keyword: if specified, the indices of the first size True elements are returned; if there are fewer True elements than size indicates, the index arrays are padded with fill_value (default is 0).
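The size keyword makes the single-argument form usable under jax.jit, because the output length no longer depends on the data. The following is a minimal sketch, assuming a recent JAX release. `first_three_true` is a hypothetical helper name chosen for illustration, and -1 is an arbitrary fill_value picked so that padding is easy to spot.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_three_true(mask):
    # size is a static Python int, so the output shape (3,) is known at trace time.
    # One index array is returned per dimension of mask; missing entries use fill_value.
    return jnp.where(mask, size=3, fill_value=-1)

# Fewer True elements than size: the index array is padded with fill_value.
mask = jnp.array([False, True, False, False, True])
(idx,) = first_three_true(mask)
print(idx.tolist())       # [1, 4, -1]

# More True elements than size: only the first size indices are kept.
mask_many = jnp.array([True, True, True, True, False])
(idx_many,) = first_three_true(mask_many)
print(idx_many.tolist())  # [0, 1, 2]
```

Choosing a fill_value that cannot be a valid index makes padded slots distinguishable from real results. With the default of 0, a padded slot looks the same as a genuine match at index 0.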
Original docstring below.
When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided.
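The following sketch illustrates the single-argument shorthand. Outside of jit, `jnp.where(cond)` and `jnp.nonzero(cond)` return the same tuple of index arrays, with one array per dimension of `cond`. The input values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp

cond = jnp.array([[True, False, True],
                  [False, False, True]])

# Single-argument form: a tuple of index arrays, one per dimension.
rows, cols = jnp.where(cond)
# The preferred, equivalent spelling.
rows_nz, cols_nz = jnp.nonzero(cond)

print(rows.tolist())  # [0, 0, 1]
print(cols.tolist())  # [0, 2, 2]
```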
Parameters
condition (array_like, bool) – Where True, yield x, otherwise yield y.
x (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
y (array_like) – Values from which to choose. x, y and condition need to be broadcastable to some shape.
Returns
out – An array with elements from x where condition is True, and elements from y elsewhere.
Return type
ndarray
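The three-argument form has an output shape that depends only on the broadcast shapes of condition, x and y. That is why it can be JIT-compiled without size. This is a minimal sketch, assuming a recent JAX release. `select_rows` is a hypothetical function name, and the shapes were chosen to show broadcasting: a (3, 1) condition, a (4,) x and a scalar y.

```python
import jax
import jax.numpy as jnp

@jax.jit
def select_rows(row_mask, values, default):
    # Result shape is the broadcast shape of all three inputs,
    # which is fixed at trace time, so no size argument is required.
    return jnp.where(row_mask, values, default)

row_mask = jnp.array([[True], [False], [True]])  # shape (3, 1)
values = jnp.arange(4)                           # shape (4,)
out = select_rows(row_mask, values, -1)          # -1 broadcasts as a scalar

print(out.shape)     # (3, 4)
print(out.tolist())  # [[0, 1, 2, 3], [-1, -1, -1, -1], [0, 1, 2, 3]]
```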