jax.numpy.select



jax.numpy.select(condlist, choicelist, default=0)

Return an array drawn from elements in choicelist, depending on conditions.

LAX-backend implementation of numpy.select().

Original docstring below.

Parameters:
  • condlist (list of bool ndarrays) – The list of conditions which determine from which array in choicelist the output elements are taken. When multiple conditions are satisfied, the first one encountered in condlist is used.

  • choicelist (list of ndarrays) – The list of arrays from which the output elements are taken. It must have the same length as condlist.

  • default (scalar, optional) – The element inserted in output when all conditions evaluate to False.
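The parameters above can be illustrated with a short usage sketch adapted from the standard NumPy example. At each position the first condition in condlist that holds picks the choicelist entry, and positions where no condition holds receive default. The names `x`, `disjoint`, and `overlapping` are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(6)  # [0, 1, 2, 3, 4, 5]

# Non-overlapping conditions: at x == 3 neither holds, so default (42) is used.
disjoint = jnp.select([x < 3, x > 3], [x, x**2], default=42)
print(disjoint.tolist())  # [0, 1, 2, 42, 16, 25]

# Overlapping conditions: at x == 2 both hold, and the first entry (x) wins.
overlapping = jnp.select([x < 3, x > 1], [x, x**2], default=42)
print(overlapping.tolist())  # [0, 1, 2, 9, 16, 25]
```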

Returns:

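output – The output at position m is the m-th element of the array in choicelist whose corresponding array in condlist is True at position m. If several conditions are True there, the first one in condlist is used; if none is, the output element is default.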

Return type:

ndarray
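The selection rule described under Returns can be cross-checked against a chain of `jnp.where` calls. The sketch below only illustrates the semantics and is not claimed to be how JAX implements `select`; the helper name `select_via_where` is hypothetical. It starts from default and applies the conditions from last to first, so an earlier condition overwrites a later one and therefore takes precedence.

```python
import jax.numpy as jnp

def select_via_where(condlist, choicelist, default=0):
    """Hypothetical reference for jnp.select semantics, built from jnp.where."""
    if len(condlist) != len(choicelist):
        raise ValueError("condlist and choicelist must have the same length")
    # Start from default; jnp.where broadcasts it against each condition and choice.
    out = default
    # Walk the lists in reverse so the first matching condition is applied last and wins.
    for cond, choice in zip(reversed(condlist), reversed(choicelist)):
        out = jnp.where(cond, choice, out)
    return jnp.asarray(out)

x = jnp.arange(-3, 4)  # [-3, -2, -1, 0, 1, 2, 3]
# x < 0 and x <= -2 overlap at -3 and -2; the first condition (x < 0) must win there.
condlist = [x < 0, x <= -2, x > 1]
choicelist = [-x, jnp.full_like(x, 100), 2 * x]

expected = jnp.select(condlist, choicelist, default=-1)
actual = select_via_where(condlist, choicelist, default=-1)
print(expected.tolist())  # [3, 2, 1, -1, -1, 4, 6]
print(bool(jnp.array_equal(expected, actual)))  # True
```

Iterating in reverse is what encodes the "first condition wins" rule; iterating forward would let the later x <= -2 condition overwrite positions -3 and -2 with 100.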