jax.lax.select_n(which, *cases)

Selects array values from multiple cases.

Generalizes XLA’s Select operator. Unlike XLA’s version, this operator is variadic: it can select among any number of cases, using an integer-valued which as the predicate.
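A minimal sketch of the variadic, integer-indexed form, assuming JAX with its default 32-bit types (the arrays a, b, c and which are illustrative). For in-range indices the result matches NumPy's np.choose, which the last line uses as a reference.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Three candidate arrays with identical shape and dtype.
a = jnp.array([10, 11, 12, 13], dtype=jnp.int32)
b = jnp.array([20, 21, 22, 23], dtype=jnp.int32)
c = jnp.array([30, 31, 32, 33], dtype=jnp.int32)

# An integer `which` gives, element by element, the index of the case to use.
which = jnp.array([0, 2, 1, 0], dtype=jnp.int32)
out = lax.select_n(which, a, b, c)  # → [10, 31, 22, 13]

# For in-range indices this agrees with NumPy's elementwise np.choose.
reference = np.choose(np.asarray(which), [np.asarray(a), np.asarray(b), np.asarray(c)])
assert np.array_equal(np.asarray(out), reference)
```

Element 1 has which == 2, so it comes from c; element 2 has which == 1, so it comes from b.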

Parameters:

  • which (jax.typing.ArrayLike) – determines which case is returned. Must be an array of boolean or integer values, either a scalar or with the same shape as cases. For each array element, the value of which determines which of cases that element is taken from; a boolean False selects cases[0] and True selects cases[1]. which must lie in the range [0, len(cases)); for values outside that range the behavior is implementation-defined.

  • *cases (jax.typing.ArrayLike) – one or more arrays to select from, passed as positional arguments. All must have the same dtype and the same shape.


Returns:

An array with the same shape and dtype as the cases, whose values are chosen elementwise according to which.

Return type:

jax.Array
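The sketch below, assuming default JAX settings, exercises the boolean and scalar forms of which described in the parameter list, and checks that the result keeps the cases' shape and dtype. The clamping step with jnp.clip is an illustrative defensive pattern of our own for indices that might fall outside [0, len(cases)); it is not something select_n does itself.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
y = jnp.array([-1.0, -2.0, -3.0], dtype=jnp.float32)

# Boolean `which`: False (0) takes cases[0], True (1) takes cases[1].
mask = jnp.array([True, False, True])
picked = lax.select_n(mask, x, y)  # → [-1., 2., -3.]

# Scalar `which`: every element comes from the same case.
whole = lax.select_n(jnp.int32(1), x, y)  # → [-1., -2., -3.]

# Out-of-range indices give implementation-defined results, so one
# defensive option (an assumption of this sketch, not required by the API)
# is to clamp untrusted indices into [0, len(cases)) before selecting.
raw = jnp.array([-1, 5, 0], dtype=jnp.int32)
safe = lax.select_n(jnp.clip(raw, 0, 1), x, y)  # → [1., -2., 3.]

# The result has the same shape and dtype as the cases.
assert picked.shape == x.shape and picked.dtype == jnp.float32
```

Clamping trades an undefined result for a silently substituted boundary case, so it only makes sense when that substitution is acceptable for the caller.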