jax.numpy.select(condlist, choicelist, default=0)

Return an array drawn from elements in choicelist, depending on conditions.

LAX-backend implementation of select(). Original docstring below.

Parameters

  • condlist (list of bool ndarrays) – The list of conditions which determine from which array in choicelist the output elements are taken. When multiple conditions are satisfied, the first one encountered in condlist is used.

  • choicelist (list of ndarrays) – The list of arrays from which the output elements are taken. It must have the same length as condlist.

  • default (scalar, optional) – The element inserted in output when all conditions evaluate to False.
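The priority rule for condlist and the role of default can be checked directly. The sketch below assumes jax.numpy is imported as jnp and uses two overlapping conditions: positions where both hold take their value from the first array in choicelist, and positions where neither holds receive default.

```python
import jax.numpy as jnp

x = jnp.arange(6)  # [0, 1, 2, 3, 4, 5]

# x > 3 and x > 1 overlap at 4 and 5; the first listed condition wins there.
condlist = [x > 3, x > 1]
choicelist = [jnp.full_like(x, 100), jnp.full_like(x, 200)]

# Positions 0 and 1 satisfy neither condition and receive the default.
out = jnp.select(condlist, choicelist, default=-1)
print(out.tolist())  # [-1, -1, 200, 200, 100, 100]
```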


Returns

output – The output at position m is the m-th element of the first array in choicelist whose corresponding array in condlist is True at position m. If no condition is True at position m, the output there is default.

Return type

ndarray

See also

where() – Return elements from one of two arrays depending on condition.

take(), choose(), compress(), diag(), diagonal()
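Because the first True condition wins, select() over two conditions behaves like where() nested in the same order as condlist. A small check with jax.numpy, using overlapping conditions so that the order actually matters:

```python
import jax.numpy as jnp

x = jnp.arange(10)
c1, c2 = x < 5, x > 2  # overlap at x == 3 and x == 4
a, b = x, x**2

via_select = jnp.select([c1, c2], [a, b], default=0)

# The outer where tests c1 first, mirroring the order of condlist.
via_where = jnp.where(c1, a, jnp.where(c2, b, 0))

# Testing c2 first instead changes the result at the overlapping positions.
swapped = jnp.where(c2, b, jnp.where(c1, a, 0))

print(via_select.tolist())                            # [0, 1, 2, 3, 4, 25, 36, 49, 64, 81]
print(bool(jnp.array_equal(via_select, via_where)))   # True
print(swapped.tolist())                               # [0, 1, 2, 9, 16, 25, 36, 49, 64, 81]
```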


Examples

>>> import numpy as np
>>> x = np.arange(10)
>>> condlist = [x<3, x>5]
>>> choicelist = [x, x**2]
>>> np.select(condlist, choicelist)
array([ 0,  1,  2, ..., 49, 64, 81])
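In JAX, select() needs no Python control flow: every array in choicelist is computed and one value is picked per position, so it can stand in for an if/else on traced values inside jax.jit. The sketch below clamps values to [0, 1]; the function name clip01 is hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip01(x):
    # Hypothetical helper: clamp x to [0, 1] without Python branching.
    condlist = [x < 0.0, x <= 1.0]
    choicelist = [jnp.zeros_like(x), x]
    # Values above 1 match neither condition and fall through to the default.
    return jnp.select(condlist, choicelist, default=1.0)

x = jnp.array([-0.5, 0.25, 0.75, 2.0])
print(clip01(x).tolist())  # [0.0, 0.25, 0.75, 1.0]
```

For this particular case jnp.clip(x, 0.0, 1.0) is simpler; select() is more useful when each branch is a different expression.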