jax.lax.switch
jax.lax.switch(index, branches, operand)

Apply exactly one of branches, selected by index. If index is out of bounds, it is clamped to the range [0, len(branches) - 1]. Has the semantics of the following Python:
def switch(index, branches, operand):
  # Clamp index into [0, len(branches) - 1], then call the selected branch.
  index = max(0, min(index, len(branches) - 1))
  return branches[index](operand)
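A minimal usage sketch, assuming a recent JAX install. The three branch functions and the apply_branch wrapper are hypothetical and only illustrate the clamping behavior described above. Every branch receives the same operand and returns a value of the same shape and dtype, which lax.switch requires. Under jax.jit the index is a traced integer, so the branch choice happens at run time rather than at trace time.

```python
import jax
import jax.numpy as jnp

# Hypothetical branches: each takes the same operand and returns a float32 scalar.
branches = [
    lambda x: x + 1.0,   # index 0
    lambda x: x * 2.0,   # index 1
    lambda x: -x,        # index 2
]

@jax.jit
def apply_branch(index, x):
    # index is traced under jit; lax.switch clamps it into [0, len(branches) - 1].
    return jax.lax.switch(index, branches, x)

x = jnp.float32(3.0)
print(apply_branch(0, x))   # 4.0
print(apply_branch(1, x))   # 6.0
print(apply_branch(5, x))   # out of range, clamped to index 2: -3.0
print(apply_branch(-1, x))  # out of range, clamped to index 0: 4.0
```

Because an out-of-range index is clamped rather than rejected, it silently selects the first or last branch; if that could mask a bug, validate the index separately before calling switch.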