jax.lax.cond

jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)

Conditionally apply true_fun or false_fun.

cond() has equivalent semantics to this Python implementation:

def cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)
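The Python reference above only works when pred is a concrete value. The sketch below uses cond inside jax.jit, where `x > 0` is a traced value and a plain Python `if` on it would fail with a concretization error. It assumes a recent JAX release; `safe_reciprocal` is a hypothetical example function, not part of JAX:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def safe_reciprocal(x):
    # Under jit, `x > 0` is a traced boolean scalar, so a Python `if` on it
    # cannot be evaluated at trace time; lax.cond picks the branch at run time.
    return lax.cond(
        x > 0,
        lambda v: 1.0 / v,             # true_fun: applied when x > 0
        lambda v: jnp.zeros_like(v),   # false_fun: same shape and dtype as true_fun's output
        x,
    )


print(float(safe_reciprocal(jnp.float32(4.0))))   # 0.25
print(float(safe_reciprocal(jnp.float32(-2.0))))  # 0.0
```

Explicit `jnp.float32` inputs are used so that both branches produce the same strongly typed float32 scalar.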

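pred must be a scalar, i.e. a value with shape (). Unlike in the Python version above, pred may be a traced value (for example inside jax.jit): cond traces both true_fun and false_fun and selects the branch to execute at run time.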
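Because pred must be a scalar, one common way to branch on a batch of predicates is to map cond with jax.vmap. With a batched predicate, JAX converts cond into a select, so both branches are evaluated for every element. A minimal sketch assuming a recent JAX release; `clip_or_negate` is a hypothetical helper:

```python
import jax
import jax.numpy as jnp
from jax import lax


def clip_or_negate(p, x):
    # Inside vmap, each call sees a scalar predicate `p` and a scalar `x`.
    return lax.cond(p, lambda v: jnp.minimum(v, 1.0), lambda v: -v, x)


preds = jnp.array([True, False, True])
xs = jnp.array([3.0, 2.0, 0.5], dtype=jnp.float32)
out = jax.vmap(clip_or_negate)(preds, xs)
print(out.tolist())  # [1.0, -2.0, 0.5]
```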

Parameters
  • pred – Boolean scalar type, indicating which branch function to apply.

  • true_fun (Callable) – Function (A -> B), to be applied if pred is True.

  • false_fun (Callable) – Function (A -> B), to be applied if pred is False.

  • operands – Operands (A) passed to whichever branch pred selects. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
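The operands are passed positionally to whichever branch runs, so both branches must accept the same arguments. A sketch with two operands, one of them a dict pytree, assuming a recent JAX release; `scale`, `reset`, `params`, and `factor` are hypothetical names used only for illustration:

```python
import jax.numpy as jnp
from jax import lax


def scale(params, factor):
    # true_fun: receives both operands; `params` is a dict pytree of arrays.
    return {k: v * factor for k, v in params.items()}


def reset(params, factor):
    # false_fun: takes the same operands (factor is unused here) and must
    # return the same dict structure, shapes and dtypes as `scale`.
    del factor
    return {k: jnp.zeros_like(v) for k, v in params.items()}


params = {"w": jnp.ones((2,), dtype=jnp.float32), "b": jnp.float32(0.5)}
factor = jnp.float32(3.0)

scaled = lax.cond(True, scale, reset, params, factor)
cleared = lax.cond(False, scale, reset, params, factor)
print(scaled["w"].tolist(), float(scaled["b"]))    # [3.0, 3.0] 1.5
print(cleared["w"].tolist(), float(cleared["b"]))  # [0.0, 0.0] 0.0
```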

Returns

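Value (B) of either true_fun(*operands) or false_fun(*operands), depending on the value of pred. Both branches must return the same pytree structure, with matching shapes and dtypes at every leaf. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.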
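The sketch below returns a tuple pytree from both branches and then shows a pair of branches whose outputs differ in shape. It assumes a recent JAX release, in which such a mismatch is reported as a TypeError while the branches are traced; `stats` and `zeros` are hypothetical helpers:

```python
import jax.numpy as jnp
from jax import lax


def stats(x):
    # true_fun: returns a tuple pytree (mean, max) of float32 scalars.
    return jnp.mean(x), jnp.max(x)


def zeros(x):
    # false_fun: same tuple structure, shapes and dtypes as `stats`.
    return jnp.zeros((), x.dtype), jnp.zeros((), x.dtype)


x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
mean, mx = lax.cond(True, stats, zeros, x)
print(float(mean), float(mx))  # 2.0 3.0

# Branch outputs of shape (3,) and (2,) do not match, so cond rejects them.
mismatch_error = None
try:
    lax.cond(True, lambda v: v, lambda v: v[:2], x)
except TypeError as err:
    mismatch_error = err
```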