jax.lax.cond
- jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)
Conditionally apply true_fun or false_fun. Wraps XLA’s Conditional operator.
Provided arguments are correctly typed, cond() has equivalent semantics to this Python implementation, where pred must be a scalar type:

```python
def cond(pred, true_fun, false_fun, *operands):
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)
```
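The Python reference above only works when pred is a concrete value. The following is a minimal sketch, assuming a traced scalar under jax.jit, that shows cond() agreeing with the reference implementation while still branching on a traced value. The names python_cond, double, negate and step are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def python_cond(pred, true_fun, false_fun, *operands):
    # Hypothetical helper: the reference semantics, usable only with a concrete pred.
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)


def double(v):
    return v * 2.0


def negate(v):
    return -v


@jax.jit
def step(x):
    # Inside jit, x > 0 is a traced boolean scalar, so a Python `if` on it would
    # fail; lax.cond picks the branch when the compiled function runs.
    return lax.cond(x > 0, double, negate, x)


for value in (3.0, -1.5):
    x = jnp.float32(value)
    # Both calls take the same branch and return the same value.
    print(step(x), python_cond(value > 0, double, negate, x))
```

Both branches here map a float32 scalar to a float32 scalar, which keeps the output type B identical whichever branch is taken.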
In contrast with jax.lax.select(), using cond indicates that only one of the two branches is executed (up to compiler rewrites and optimizations). However, when transformed with vmap() to operate over a batch of predicates, cond is converted to select(), so both branches are evaluated and their results are selected elementwise.
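The vmap() behavior described above can be observed with a small sketch. The names branchy, preds and xs are illustrative, not part of the API. Batching the predicate makes the result match an elementwise lax.select() over both branch outputs, and in current JAX versions the staged program for the batched function contains a select_n primitive instead of a cond; the exact primitive names are an implementation detail and may change between releases.

```python
import jax
import jax.numpy as jnp
from jax import lax


def branchy(p, x):
    # One scalar predicate per example; each branch maps float32 -> float32.
    return lax.cond(p, lambda v: v + 1.0, lambda v: v - 1.0, x)


preds = jnp.array([True, False, True])
xs = jnp.array([0.0, 10.0, 20.0], dtype=jnp.float32)

# Batched over the predicate: cond is rewritten into a select, so both branches run.
batched = jax.vmap(branchy)(preds, xs)

# The same computation written directly with lax.select, evaluating both branches.
selected = lax.select(preds, xs + 1.0, xs - 1.0)

print(batched)   # values 1, 9, 21
print(selected)  # values 1, 9, 21

# Inspect the staged program for the batched function.
print(jax.make_jaxpr(jax.vmap(branchy))(preds, xs))
```

Because both branches are computed under vmap, an expensive or unsafe branch (for example one that would divide by zero) still runs for every batch element.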
- Parameters:
pred – Boolean scalar type, indicating which branch function to apply.
true_fun (Callable) – Function (A -> B), to be applied if pred is True.
false_fun (Callable) – Function (A -> B), to be applied if pred is False.
operands – Operands (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
- Returns:
Value (B) of either true_fun(*operands) or false_fun(*operands), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
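Operands and return values may both be pytrees. The sketch below assumes hypothetical names (params, scale, scaled, unchanged, mismatched). It illustrates that both branches receive the same operands, and that both must return the same pytree structure with matching shapes and dtypes, since the result B has a single type regardless of which branch runs. Branches whose outputs differ, here float32 versus int32, are rejected with a TypeError when cond traces them.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical pytree operand: a dict of float32 arrays.
params = {"w": jnp.array([1.0, 2.0], dtype=jnp.float32), "b": jnp.float32(0.5)}
scale = jnp.float32(3.0)


def scaled(p, s):
    # Returns a dict with the same keys, shapes and dtypes as p.
    return {"w": p["w"] * s, "b": p["b"] * s}


def unchanged(p, s):
    # Ignores s but must still accept it: every branch gets all operands.
    return p


out_true = lax.cond(True, scaled, unchanged, params, scale)
out_false = lax.cond(False, scaled, unchanged, params, scale)
print(out_true)   # w scaled to [3., 6.], b scaled to 1.5
print(out_false)  # the original params


def mismatched(v):
    # Hypothetical branch whose output dtype (int32) differs from the other branch.
    return v.astype(jnp.int32)


try:
    lax.cond(True, lambda v: v * 1.5, mismatched, jnp.ones(3, dtype=jnp.float32))
    rejected = False
except TypeError:
    # Raised while tracing: the two branch outputs do not share one type B.
    rejected = True
print("mismatched branches rejected:", rejected)
```

Note that the check happens even though pred is a concrete True: cond traces both branches before choosing one.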