jax.lax.select_n
- jax.lax.select_n(which, *cases)
Selects array values from multiple cases.
Generalizes XLA’s Select operator. Unlike XLA’s version, this operator is variadic: it can select from many cases using an integer which argument.
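A minimal sketch of both modes, assuming jax is installed and using arbitrary illustrative arrays. A boolean which with two cases behaves like a two-way Select: False takes cases[0] and True takes cases[1]. That order is the reverse of lax.select(pred, on_true, on_false). An integer which indexes into cases elementwise.

```python
import jax.numpy as jnp
from jax import lax

# Boolean `which` with two cases: False picks cases[0], True picks cases[1].
pred = jnp.array([True, False, True])
on_false = jnp.array([1, 2, 3], dtype=jnp.int32)
on_true = jnp.array([7, 8, 9], dtype=jnp.int32)
two_way = lax.select_n(pred, on_false, on_true)

# Integer `which` with three cases: each element of `which` is an index
# into `cases`, applied element by element.
which = jnp.array([0, 2, 1, 0], dtype=jnp.int32)
a = jnp.array([10.0, 11.0, 12.0, 13.0], dtype=jnp.float32)
b = jnp.array([20.0, 21.0, 22.0, 23.0], dtype=jnp.float32)
c = jnp.array([30.0, 31.0, 32.0, 33.0], dtype=jnp.float32)
many_way = lax.select_n(which, a, b, c)
```

Here two_way holds [7, 2, 9] and many_way holds [10., 31., 22., 13.]: position 1 takes c because which[1] is 2, and position 2 takes b because which[2] is 1.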
- Parameters:
which (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – determines which case should be returned. Must be an array containing either boolean or integer values. May either be a scalar or have a shape matching cases. For each array element, the value of which determines which of cases is taken. which must be in the range [0 .. len(cases)); for values outside that range the behavior is implementation-defined.
*cases (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a non-empty list of array cases. All must have equal dtypes and equal shapes.
- Return type:
Array
- Returns:
An array with shape and dtype equal to the cases, whose values are chosen according to which.
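The sketch below computes which from comparisons inside jax.jit to evaluate a piecewise function. piecewise is a hypothetical helper, not part of JAX. Every case shares the shape and dtype of x, so the result keeps that shape and dtype. Each case is computed as an ordinary array argument; select_n only chooses among the results. The last line passes a scalar which, which takes the same case for every element.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def piecewise(x):
    # Hypothetical helper: region index 0 for x <= 0, 1 for 0 < x <= 1, 2 for x > 1.
    which = (x > 0).astype(jnp.int32) + (x > 1).astype(jnp.int32)
    # All three cases have the same shape and dtype as x, as select_n requires.
    return lax.select_n(which, jnp.zeros_like(x), x * x, 2 * x - 1)

x = jnp.array([-1.0, 0.5, 3.0], dtype=jnp.float32)
y = piecewise(x)  # shape (3,), dtype float32

# A scalar `which` selects the same case (here cases[2]) for every element.
z = lax.select_n(jnp.array(2, dtype=jnp.int32), x, x * x, 2 * x - 1)
```

With these inputs y is [0., 0.25, 5.] and z equals 2 * x - 1, that is [-3., 0., 5.].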