- jax.lax.switch(index, branches, *operands, operand=<object object>)
Apply exactly one of `branches`, chosen by `index`.
If `index` is out of bounds, it is clamped to the valid range `[0, len(branches) - 1]`.
Has the semantics of the following Python:

    def switch(index, branches, *operands):
        # Clamp index into the valid range [0, len(branches) - 1].
        index = min(max(index, 0), len(branches) - 1)
        return branches[index](*operands)
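A minimal sketch, assuming a recent JAX release: it calls `jax.lax.switch` with a traced index inside `jax.jit`, where plain Python list indexing (`branches[index]`) would fail because a tracer cannot be converted to a Python int. It also shows out-of-range indices being clamped instead of raising. The branch functions and `apply_branch` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical branches for illustration; each has the same signature (A -> B).
branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

@jax.jit
def apply_branch(index, x):
    # Under jit, `index` is a traced integer scalar, so the branch is
    # selected at run time by lax.switch rather than by Python indexing.
    return jax.lax.switch(index, branches, x)

x = jnp.float32(3.0)
print(apply_branch(0, x))   # x + 1.0 -> 4.0
print(apply_branch(1, x))   # x * 2.0 -> 6.0
print(apply_branch(2, x))   # -x      -> -3.0
# Out-of-range indices are clamped, not rejected:
print(apply_branch(7, x))   # clamped to 2 -> -3.0
print(apply_branch(-4, x))  # clamped to 0 -> 4.0
```

Because of the clamping, an out-of-range index silently selects the first or last branch; if that could mask a bug, validate the index before calling `switch`.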
Returns: the value (B) of `branch(*operands)` for the branch selected by `index`, where each branch maps the operands (A) to an output (B).
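Since the selected branch is only known at run time, every branch must return the same output structure (same pytree layout, shapes, and dtypes). A minimal sketch, assuming two hypothetical branches `branch_a` and `branch_b` that each take two operands and return a dict with identical keys and types:

```python
import jax
import jax.numpy as jnp

# Hypothetical branches for illustration: both take (x, y) and return a
# dict with the same keys, shapes, and dtypes, as lax.switch requires.
def branch_a(x, y):
    return {"out": x + y, "flag": jnp.int32(0)}

def branch_b(x, y):
    return {"out": x - y, "flag": jnp.int32(1)}

x = jnp.float32(5.0)
y = jnp.float32(2.0)

# Operands are passed positionally after the sequence of branches.
result = jax.lax.switch(1, [branch_a, branch_b], x, y)
print(result["out"], result["flag"])  # branch_b ran: out = 3.0, flag = 1
```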