jax.lax.switch(index, branches, operand)

Apply exactly one of the functions in branches, selected by index.

If index is out of bounds, it is clamped to within bounds.
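The clamping behavior can be illustrated with a minimal sketch. The three lambda branches and the input value are hypothetical, chosen only so that each branch gives a distinguishable result:

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical branches; all take and return a float32 scalar.
branches = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x]
x = jnp.float32(3.0)

# An index below 0 is clamped to 0, so the first branch runs.
below = lax.switch(-5, branches, x)

# An index above len(branches) - 1 is clamped to 2, so the last branch runs.
above = lax.switch(7, branches, x)
```

No error is raised for either out-of-range index, so callers that need strict bounds checking would have to validate index themselves.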

switch has the semantics of the following Python:

def switch(index, branches, operand):
  # Clamp index into the valid range [0, len(branches) - 1].
  index = max(0, min(index, len(branches) - 1))
  return branches[index](operand)

Parameters:

  • index – Integer scalar indicating which branch function to apply.

  • branches (Sequence[Callable]) – Sequence of functions (A -> B) to be applied based on index.

  • operand – Operand (A) input to whichever branch is applied.
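As a usage sketch of the parameters above, the following selects a branch with an index that is traced under jax.jit. The function names negate, double, square, and apply_branch are illustrative assumptions, not part of the API:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical branch functions; each maps an array to an array of the
# same shape and dtype, as all branches must agree on their output type.
def negate(x):
    return -x

def double(x):
    return 2.0 * x

def square(x):
    return x * x

@jax.jit
def apply_branch(index, x):
    # index is a traced integer scalar here; x is the single operand.
    return lax.switch(index, [negate, double, square], x)

x = jnp.array([1.0, 2.0, 3.0])
results = [apply_branch(i, x) for i in range(3)]
```

Because the branch is chosen by index at run time, the same compiled function serves every index value without retracing.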