jax.lax.cond

jax.lax.cond(pred, true_fun, false_fun, operand)

Conditionally apply true_fun or false_fun.

cond() has semantics equivalent to this Python implementation:

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
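The equivalence above describes the result, but one practical difference matters in traced code: inside jax.jit the predicate is usually a traced value, so a Python `if` on it fails during tracing, while cond chooses the branch when the compiled code runs. A minimal sketch, where safe_reciprocal is a hypothetical function written only for illustration:

```python
import jax
import jax.numpy as jnp


@jax.jit
def safe_reciprocal(x):
  # Inside jit, `x > 0` is a traced boolean, so `if x > 0:` would raise an
  # error during tracing. lax.cond selects the branch at run time instead.
  return jax.lax.cond(
      x > 0,
      lambda v: 1.0 / v,            # true_fun
      lambda v: jnp.zeros_like(v),  # false_fun
      x)


print(float(safe_reciprocal(jnp.float32(4.0))))   # → 0.25
print(float(safe_reciprocal(jnp.float32(-2.0))))  # → 0.0
```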

pred must be a scalar, i.e. a value of shape ().
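Because pred must be a scalar, an array-valued condition has to be reduced to a single value first. One sketch, assuming the goal is to clip an array only when any of its entries is negative (clip_if_any_negative is a hypothetical name), uses jnp.any for the reduction. If a per-element choice is wanted instead, jnp.where is the usual tool, since cond always selects one whole branch.

```python
import jax
import jax.numpy as jnp


def clip_if_any_negative(x):
  # `x < 0` has the same shape as `x`, so it cannot be passed as `pred`.
  # jnp.any reduces it to a single boolean of shape ().
  pred = jnp.any(x < 0)
  return jax.lax.cond(
      pred,
      lambda v: jnp.maximum(v, 0.0),  # true_fun: replace negatives with 0
      lambda v: v,                    # false_fun: return x unchanged
      x)


print(clip_if_any_negative(jnp.array([1.0, -2.0, 3.0])).tolist())  # → [1.0, 0.0, 3.0]
```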

The branch functions true_fun and false_fun may not need the operand to compute their result, but an operand must still be passed to cond and accepted by both branch functions, e.g.:

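import jax

def get_predicate_value():  # hypothetical stand-in for a computed scalar predicate
  return True

jax.lax.cond(
    get_predicate_value(),
    lambda _: 23,
    lambda _: 42,
    operand=None)

Parameters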
  • pred – Boolean scalar type, indicating which branch function to apply.

  • true_fun (Callable) – Function (A -> B), to be applied if pred is True.

  • false_fun (Callable) – Function (A -> B), to be applied if pred is False.

  • operand – Operand (A) passed to whichever branch pred selects. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
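To illustrate the pytree case for operand, the sketch below passes a dict as A and returns a dict with the same structure and dtypes from both branches. The names state, add_sample and skip are hypothetical and chosen only for this example.

```python
import jax
import jax.numpy as jnp

# A hypothetical pytree operand: a dict holding two scalar arrays.
state = {"count": jnp.int32(0), "total": jnp.float32(0.0)}


def add_sample(s):
  # true_fun: update both leaves, keeping their dtypes (int32, float32).
  return {"count": s["count"] + 1, "total": s["total"] + 1.5}


def skip(s):
  # false_fun: return the operand unchanged (same structure and dtypes).
  return s


taken = jax.lax.cond(True, add_sample, skip, state)
skipped = jax.lax.cond(False, add_sample, skip, state)
print(int(taken["count"]), float(taken["total"]))      # → 1 1.5
print(int(skipped["count"]), float(skipped["total"]))  # → 0 0.0
```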

Returns

Value (B) of either true_fun(operand) or false_fun(operand), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
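Since both branches share the output type B, their results must agree in pytree structure, shapes and dtypes; cond traces both branches and rejects a mismatch. A minimal sketch contrasting a valid call with one whose branches return different shapes (pick and pick_mismatched are hypothetical names):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3.0)  # float32 array of shape (3,)


@jax.jit
def pick(pred, v):
  # Both branches return a float32 array of shape (3,): the types agree.
  return jax.lax.cond(pred, lambda u: u * 2, lambda u: u + 1, v)


@jax.jit
def pick_mismatched(pred, v):
  # true_fun returns shape (3,), false_fun returns shape (): types disagree.
  return jax.lax.cond(pred, lambda u: u, lambda u: u.sum(), v)


print(pick(False, x).tolist())  # → [1.0, 2.0, 3.0]

rejected = False
try:
  pick_mismatched(False, x)
except TypeError:
  # cond raises while tracing because the branch output types differ.
  rejected = True
print(rejected)  # → True
```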