jax.numpy.select

jax.numpy.select(condlist, choicelist, default=0)

Return an array drawn from elements in choicelist, depending on conditions.

LAX-backend implementation of numpy.select().
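Because the function is implemented with LAX primitives, it can be traced by JAX transformations such as jax.jit and jax.grad. The sketch below assumes a hypothetical piecewise function named piecewise; it is an illustration of that composability, not an excerpt from the JAX source.

```python
import jax
import jax.numpy as jnp

@jax.jit
def piecewise(x):
    # Hypothetical piecewise function: -1 for x < 0, x**2 for 0 <= x < 1,
    # and the default value 1.0 for x >= 1.
    return jnp.select([x < 0, x < 1], [-jnp.ones_like(x), x ** 2], default=1.0)

xs = jnp.array([-3.0, 0.5, 4.0])
values = piecewise(xs)  # values: [-1.0, 0.25, 1.0]

# Gradients flow only through the branch selected at each position:
# the constant branches contribute 0, the x**2 branch contributes 2*x.
grads = jax.grad(lambda v: piecewise(v).sum())(xs)  # grads: [0.0, 1.0, 0.0]
```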

Original docstring below.

Parameters
  • condlist (list of bool ndarrays) – The list of conditions which determine from which array in choicelist the output elements are taken. When multiple conditions are satisfied, the first one encountered in condlist is used.

  • choicelist (list of ndarrays) – The list of arrays from which the output elements are taken. It must have the same length as condlist.

  • default (scalar, optional) – The element inserted in the output when all conditions evaluate to False.
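A minimal example of how the three parameters interact, assuming integer inputs: the two conditions below overlap for x = 0, 1, 2, so the first matching condition decides the output there, and positions where neither condition holds receive default.

```python
import jax.numpy as jnp

x = jnp.arange(6)  # [0, 1, 2, 3, 4, 5]

# Overlapping conditions: both x < 3 and x < 5 are True for x = 0, 1, 2.
condlist = [x < 3, x < 5]
choicelist = [x * 10, x * 100]

# The first True condition wins; x = 5 matches nothing and gets the default.
out = jnp.select(condlist, choicelist, default=-1)
# out holds [0, 10, 20, 300, 400, -1]
```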

Returns

output – The output at position m is the m-th element of the array in choicelist that corresponds to the first array in condlist whose m-th element is True. Positions where no condition is True are filled with default.
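One way to read the rule above is as a chain of jnp.where calls applied from the last condition back to the first, so that earlier conditions overwrite later ones. The helper select_via_where below is hypothetical (not part of JAX) and is only a sketch for checking the semantics, not how jax.numpy.select is necessarily implemented internally.

```python
import jax.numpy as jnp

def select_via_where(condlist, choicelist, default=0):
    """Hypothetical reference helper (not a JAX API): emulates select with where."""
    if len(condlist) != len(choicelist):
        raise ValueError("condlist and choicelist must have the same length")
    out = default
    # Walk the conditions in reverse so the first True condition is applied last
    # and therefore takes priority at each position.
    for cond, choice in zip(reversed(condlist), reversed(choicelist)):
        out = jnp.where(cond, choice, out)
    return out

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
condlist = [x < -1, x < 1]
choicelist = [jnp.full_like(x, -1.0), x]

a = jnp.select(condlist, choicelist, default=1.0)
b = select_via_where(condlist, choicelist, default=1.0)
print(jnp.array_equal(a, b))  # True
```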

Return type

ndarray