jax.lax.switch
jax.lax.switch(index, branches, *operands, operand=<object object>)
Apply exactly one of the branches, selected by index. If index is out of bounds, it is clamped into the valid range [0, len(branches) - 1]. Has the semantics of the following Python:
```python
def switch(index, branches, *operands):
  index = clamp(0, index, len(branches) - 1)
  return branches[index](*operands)
```
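A minimal usage sketch, assuming a standard JAX install. The three branch functions and the input values are made up for illustration; the example shows ordinary branch selection and the clamping of out-of-range indices described above.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical branches for illustration; each maps a float array to a float
# array of the same shape and dtype, as lax.switch requires.
branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

x = jnp.array([1.0, 2.0, 3.0])

y0 = lax.switch(0, branches, x)  # runs branch 0
y2 = lax.switch(2, branches, x)  # runs branch 2

# Out-of-range indices are clamped instead of raising an error.
y_low = lax.switch(-5, branches, x)   # clamped to 0, same result as y0
y_high = lax.switch(10, branches, x)  # clamped to 2, same result as y2

print(y0, y2, y_low, y_high)
```

Because of the clamping, an index of -5 behaves like 0 and an index of 10 behaves like the last branch; callers that need an error on invalid indices have to check the index themselves.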
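Internally this wraps XLA’s Conditional operator. However, when transformed with vmap() to operate over a batch of indices, switch is converted to select(): every branch is evaluated on the whole batch, and the per-example results are then selected according to each example's index.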
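The sketch below contrasts the two cases, assuming a standard JAX install; the branches, indices, and inputs are hypothetical. Under jit alone the index is a traced scalar and the switch is staged out as one conditional; under vmap with a batched index, all branches run on the batch and the results are picked elementwise.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branches over scalar operands.
branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

def apply_one(i, x):
    # Unbatched: scalar index i picks one branch for scalar operand x.
    return lax.switch(i, branches, x)

# jit without vmap: the index is a traced scalar, and the switch is staged
# out as a single conditional.
y_jit = jax.jit(apply_one)(1, 3.0)

# vmap over a batch of indices: each example may pick a different branch,
# so all branches are evaluated on the batch and the results are selected
# elementwise. Index 5 is out of bounds and is clamped to 2.
indices = jnp.array([0, 1, 2, 5])
xs = jnp.array([1.0, 2.0, 3.0, 4.0])
ys = jax.vmap(apply_one)(indices, xs)

print(y_jit, ys)
```

One practical consequence of the select-based batching is cost: under vmap with a batched index, the work is roughly the sum of all branches rather than only the chosen one, which matters when one branch is much more expensive than the others.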
- Parameters:
index – Integer scalar indicating which branch function to apply.
branches (Sequence[Callable]) – Sequence of functions (A -> B) to be applied based on index. All branches must return the same output structure (the same pytree structure, with matching shapes and dtypes).
operands – Operands (A) passed to whichever branch is applied.
- Returns:
Value (B) of branch(*operands) for the branch that was selected based on index.
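A minimal sketch of the parameter and return conventions, assuming a standard JAX install; add_branch and mul_branch are hypothetical helpers. Both branches accept the same operands (a, b) and return the same pytree structure, a dict whose leaves have matching shapes and dtypes, which is what lets lax.switch choose between them.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical branches: both take two scalar operands and return a dict with
# a scalar "value" and a length-2 "vec" array, all float32.
def add_branch(a, b):
    return {"value": a + b, "vec": jnp.stack([a, b])}

def mul_branch(a, b):
    return {"value": a * b, "vec": jnp.stack([b, a])}

a = jnp.float32(3.0)
b = jnp.float32(4.0)

# Extra positional arguments after the branch list are forwarded as operands.
out0 = lax.switch(0, [add_branch, mul_branch], a, b)
out1 = lax.switch(1, [add_branch, mul_branch], a, b)

print(out0["value"], out0["vec"])
print(out1["value"], out1["vec"])
```

If one branch returned, say, a tuple while another returned a dict, or the leaf shapes differed, lax.switch would raise an error at trace time rather than at run time.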