jax.lax.select_n

jax.lax.select_n(which, *cases)

Selects array values from multiple cases.

Generalizes XLA's Select operator. XLA's Select chooses between two cases using a boolean predicate. This operator is variadic instead: an integer-valued which can select from any number of cases.
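Below is a minimal sketch, assuming `jax` is installed. In the first call, an integer `which` picks element by element from three float cases that all have the same shape. In the second call, a boolean `which` behaves like a two-case select: `False` takes the first case and `True` takes the second.

```python
import jax.numpy as jnp
from jax import lax

# Three cases, each with shape (4,) and dtype float32.
a = jnp.array([1.0, 2.0, 3.0, 4.0])
b = jnp.array([10.0, 20.0, 30.0, 40.0])
c = jnp.array([100.0, 200.0, 300.0, 400.0])

# Integer selector: element i of the result comes from cases[which[i]].
which = jnp.array([0, 2, 1, 0], dtype=jnp.int32)
picked = lax.select_n(which, a, b, c)
# picked == [1., 200., 30., 4.]

# Boolean selector: False -> first case, True -> second case.
mask = jnp.array([True, False, True, False])
two_way = lax.select_n(mask, a, b)
# two_way == [10., 2., 30., 4.]
```

The boolean form matches a conventional "on_false, on_true" argument order, because False and True act as the indices 0 and 1.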

Parameters:
  • which (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – determines which case should be returned. Must be an array containing either boolean or integer values. May be a scalar or have the same shape as the cases. For each array element, the value of which determines which of the cases is taken; a boolean which takes cases[0] where it is False and cases[1] where it is True. which must be in the range [0 .. len(cases)); for values outside that range the behavior is implementation-defined.

  • *cases (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a non-empty sequence of array cases, passed as positional arguments. All must have the same dtype and the same shape.
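Out-of-range values of which give implementation-defined results. A cautious pattern is therefore to clamp the selector before calling select_n. The helper name `clamped_select_n` in the sketch below is hypothetical and is not part of `jax.lax`. The sketch also shows a scalar (shape `()`) which, which selects an entire case.

```python
import jax.numpy as jnp
from jax import lax

def clamped_select_n(which, *cases):
    """Hypothetical helper: clamp an integer selector into [0, len(cases)) first."""
    # jnp.clip with Python int bounds keeps the integer dtype of `which`.
    safe_which = jnp.clip(which, 0, len(cases) - 1)
    return lax.select_n(safe_which, *cases)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# -1 is clamped to 0 and 7 is clamped to 1 before selection.
out = clamped_select_n(jnp.array([-1, 1, 7], dtype=jnp.int32), x, y)
# out == [1., 5., 6.]

# A scalar selector picks the whole of one case.
whole = lax.select_n(jnp.int32(1), x, y)
# whole == [4., 5., 6.]
```

Clamping is only one possible policy. Depending on the application, it may be better to validate which upstream than to silently remap bad indices.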

Return type:

Array

Returns:

An array whose shape and dtype equal those of the cases, with each element chosen according to the corresponding value of which.
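The next sketch checks the return contract inside a jit-compiled function. NumPy's `np.choose` serves here only as an independent reference for in-range selectors. It is not a description of how select_n is implemented.

```python
import numpy as np
import jax
from jax import lax

@jax.jit
def pick(which, a, b, c):
    # select_n is traceable, so it works inside jit-compiled functions.
    return lax.select_n(which, a, b, c)

rng = np.random.default_rng(0)
# In-range integer selector in {0, 1, 2} with the same shape as the cases.
which = rng.integers(0, 3, size=(2, 3)).astype(np.int32)
cases = [rng.standard_normal((2, 3)).astype(np.float32) for _ in range(3)]

result = pick(which, *cases)
reference = np.choose(which, cases)

# Shape and dtype follow the cases; values agree with np.choose.
assert result.shape == cases[0].shape
assert result.dtype == cases[0].dtype
assert np.array_equal(np.asarray(result), reference)
```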