jax.lax.switch

jax.lax.switch(index, branches, *operands, operand=<object object>)

Apply exactly one of branches, selected by index.

If index is out of bounds, it is clamped to the valid range [0, len(branches) - 1].
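A small sketch of the clamping behavior. The three branch functions and the operand value are illustrative only; any callables with matching output types would do.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative branches; each takes one operand x.
branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: x - 3.0,  # index 2
]

x = jnp.float32(10.0)

in_range = lax.switch(1, branches, x)  # branch 1: 10 * 2 = 20
too_low = lax.switch(-5, branches, x)  # clamped to 0: 10 + 1 = 11
too_high = lax.switch(7, branches, x)  # clamped to 2: 10 - 3 = 7
```

Out-of-range indices do not raise an error; they silently select the first or last branch.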

Has the semantics of the following Python:

def switch(index, branches, *operands):
  # Clamp index into [0, len(branches) - 1], then apply that branch.
  index = min(max(index, 0), len(branches) - 1)
  return branches[index](*operands)
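The Python model above only describes the result. Under jax.jit the index is a tracer, so plain list indexing such as branches[index] cannot be used, whereas lax.switch selects the branch at run time. A minimal sketch, where apply_op is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp
from jax import lax

def apply_op(op_index, x):
    # op_index is traced under jit, so branches[op_index] would fail;
    # lax.switch performs the selection inside the compiled computation.
    return lax.switch(op_index, [jnp.sin, jnp.cos, lambda v: v ** 2], x)

apply_op_jit = jax.jit(apply_op)

r0 = apply_op_jit(0, jnp.float32(0.0))  # sin(0) = 0
r1 = apply_op_jit(1, jnp.float32(0.0))  # cos(0) = 1
r2 = apply_op_jit(2, jnp.float32(3.0))  # 3 ** 2 = 9
```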

Internally this wraps XLA’s Conditional operator. However, when transformed with vmap() to operate over a batch of indices, switch is converted to select(): every branch is evaluated, and the results are chosen elementwise according to each index.
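The batched case can be checked against an explicit elementwise selection. In this sketch, f, idx, and xs are illustrative names; jnp.where plays the role of the select over the two branch results.

```python
import jax
import jax.numpy as jnp
from jax import lax

branches = [
    lambda x: x + 1.0,   # index 0
    lambda x: x * 10.0,  # index 1
]

def f(i, x):
    return lax.switch(i, branches, x)

idx = jnp.array([0, 1, 1, 0])
xs = jnp.array([1.0, 2.0, 3.0, 4.0])

# vmap batches over the index, so both branches run on all of xs
# and the per-element index picks which result to keep.
out = jax.vmap(f)(idx, xs)

# The same result written as an explicit select.
expected = jnp.where(idx == 0, xs + 1.0, xs * 10.0)
```

Because every branch is computed in the batched case, expensive branches cost the same under vmap regardless of which indices are present.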

Parameters:
  • index – Integer scalar indicating which branch function to apply.

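  • branches (Sequence[Callable]) – Sequence of functions (A -> B) to be applied based on index. All branches must return the same output structure (the same pytree structure, with matching shapes and dtypes).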

  • operands – Operands (A) passed as arguments to whichever branch is applied.
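A sketch with several operands and a tuple output. The branch names add_branch and mul_branch are hypothetical; the point is that both accept the same operands (a, b) and return the same structure, a pair of float32 arrays of shape (2,).

```python
import jax.numpy as jnp
from jax import lax

# Both branches take (a, b) and return a 2-tuple of same-shaped arrays.
def add_branch(a, b):
    return a + b, a - b

def mul_branch(a, b):
    return a * b, a / b

a = jnp.array([6.0, 8.0])
b = jnp.array([2.0, 4.0])

s, d = lax.switch(0, [add_branch, mul_branch], a, b)  # ([8, 12], [4, 4])
p, q = lax.switch(1, [add_branch, mul_branch], a, b)  # ([12, 32], [3, 2])
```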

Returns:

Value (B) of branch(*operands) for the branch that was selected based on index.