jax.numpy.where

jax.numpy.where(condition, x=None, y=None)

Return elements chosen from x or y depending on condition.

LAX-backend implementation of where(). At present, JAX does not support JIT-compilation of the single-argument form of jax.numpy.where() because its output shape is data-dependent. The three-argument form does not have a data-dependent shape and can be JIT-compiled successfully.
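As a minimal sketch of the JIT-compatible three-argument form (the function name clip_negative and the sample values are illustrative, not part of the API), the output below has the same shape as the input, so the shape is known when the function is traced:

```python
import jax
import jax.numpy as jnp


@jax.jit
def clip_negative(values):
    # Three-argument form: the output shape is the broadcast shape of
    # the arguments, so it does not depend on the data.
    return jnp.where(values < 0, 0.0, values)


result = clip_negative(jnp.array([-1.5, 0.0, 2.0]))
print(result)
```

Calling jnp.where(values < 0) with a single argument inside the same function would be the data-dependent case described above.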

Original docstring below.

where(condition, [x, y])

Note

When only condition is provided, this function is a shorthand for np.asarray(condition).nonzero(). Using nonzero directly should be preferred, as it behaves correctly for subclasses. The rest of this documentation covers only the case where all three arguments are provided.
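The equivalence between the single-argument form and nonzero can be checked with a short NumPy sketch (the array contents are arbitrary sample data). The number of indices returned depends on how many entries are True, which is why this form has a data-dependent output shape:

```python
import numpy as np

condition = np.array([[True, False],
                      [False, True]])

# Single-argument form: a tuple of index arrays, one per dimension.
via_where = np.where(condition)

# The spelling the note recommends instead.
via_nonzero = np.asarray(condition).nonzero()

for where_idx, nonzero_idx in zip(via_where, via_nonzero):
    print(where_idx, nonzero_idx)
```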

condition : array_like, bool

Where True, yield x, otherwise yield y.

x, y : array_like

Values from which to choose. x, y and condition need to be broadcastable to some shape.

out : ndarray

An array with elements from x where condition is True, and elements from y elsewhere.

See also: choose; nonzero (the function that is called when x and y are omitted)

If all the arrays are 1-D, where is equivalent to:

[xv if c else yv
 for c, xv, yv in zip(condition, x, y)]

>>> a = np.arange(10)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> np.where(a < 5, a, 10*a)
array([ 0,  1,  2,  3,  4, 50, 60, 70, 80, 90])
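The 1-D equivalence with a list comprehension stated earlier can be verified directly. This is a small sketch using jax.numpy; the three input lists are made-up sample values:

```python
import jax.numpy as jnp

condition = [True, False, True]
x = [1, 2, 3]
y = [10, 20, 30]

# Pure-Python reference: pick xv where c is true, otherwise yv.
expected = [xv if c else yv for c, xv, yv in zip(condition, x, y)]

# The same selection done element-wise by where.
selected = jnp.where(jnp.array(condition), jnp.array(x), jnp.array(y))

print(expected, selected.tolist())
```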

This can be used on multidimensional arrays too:

>>> np.where([[True, False], [True, True]],
...          [[1, 2], [3, 4]],
...          [[9, 8], [7, 6]])
array([[1, 8],
       [3, 4]])

The shapes of x, y, and the condition are broadcast together:

>>> x, y = np.ogrid[:3, :4]
>>> np.where(x < y, x, 10 + y)  # both x and 10+y are broadcast
array([[10,  0,  0,  0],
       [10, 11,  1,  1],
       [10, 11, 12,  2]])
>>> a = np.array([[0, 1, 2],
...               [0, 2, 4],
...               [0, 3, 6]])
>>> np.where(a < 4, a, -1)  # -1 is broadcast
array([[ 0,  1,  2],
       [ 0,  2, -1],
       [ 0,  3, -1]])
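
To make the broadcasting rule concrete, the following sketch (with arbitrary sample shapes) combines a condition of shape (4,), an x of shape (3, 1), and a scalar y. The result takes the common broadcast shape (3, 4):

```python
import jax.numpy as jnp

condition = jnp.array([True, False, True, False])  # shape (4,)
x = jnp.array([[1], [2], [3]])                     # shape (3, 1)
y = 0                                              # scalar

# All three arguments are broadcast to shape (3, 4) before selection.
out = jnp.where(condition, x, y)

print(out.shape)
print(out)
```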