jax.lax.cond
jax.lax.cond(pred, true_fun, false_fun, operand)

Conditionally apply true_fun or false_fun.
cond() has equivalent semantics to this Python implementation:

    def cond(pred, true_fun, false_fun, operand):
      if pred:
        return true_fun(operand)
      else:
        return false_fun(operand)
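The reference implementation above only works when pred is a concrete Python value. A minimal sketch, assuming a recent JAX release, of why lax.cond is still needed: under jax.jit the predicate is a traced value, so a plain Python `if` fails, while lax.cond selects the branch at run time. The names reference_cond, abs_via_cond and abs_via_python_if are illustrative only, not part of JAX.

```python
import jax
import jax.numpy as jnp


def reference_cond(pred, true_fun, false_fun, operand):
    # Plain-Python reference semantics, copied from the text above.
    if pred:
        return true_fun(operand)
    else:
        return false_fun(operand)


@jax.jit
def abs_via_cond(x):
    # `x < 0` is a traced boolean scalar here; lax.cond picks the branch at run time.
    return jax.lax.cond(x < 0, lambda v: -v, lambda v: v, x)


@jax.jit
def abs_via_python_if(x):
    # Same logic with a Python `if`: bool() on a tracer raises during tracing.
    return reference_cond(x < 0, lambda v: -v, lambda v: v, x)


python_if_failed = False
try:
    abs_via_python_if(jnp.float32(-3.0))
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

for value in [-3.0, 0.0, 2.5]:
    print(value, abs_via_cond(jnp.float32(value)))
```

Outside of tracing, the two formulations agree; the difference only shows up once the predicate depends on a traced input.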
pred must be a scalar type.

Functions true_fun/false_fun may not need to refer to an operand to compute their result, but one must still be provided to the cond call and be accepted by both branch functions, e.g.:

    jax.lax.cond(
        get_predicate_value(),
        lambda _: 23,
        lambda _: 42,
        operand=None)
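A runnable variant of the unused-operand pattern, as a sketch: the hypothetical get_predicate_value() is replaced here by a jit-compiled function argument named flag (an illustrative name). Each branch still accepts one argument, which it ignores, and operand=None is passed explicitly.

```python
import jax


@jax.jit
def pick_constant(flag):
    # Neither branch reads its argument, but both must accept one,
    # so operand=None is supplied and bound to the ignored `_` parameter.
    return jax.lax.cond(flag, lambda _: 23, lambda _: 42, operand=None)


print(pick_constant(True))   # 23
print(pick_constant(False))  # 42
```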
- Parameters
  - pred – Boolean scalar type, indicating which branch function to apply.
  - true_fun (Callable) – Function (A -> B), to be applied if pred is True.
  - false_fun (Callable) – Function (A -> B), to be applied if pred is False.
  - operand – Operand (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
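Because operand may be any pytree, a whole structure can be routed through cond as a single argument. A minimal sketch, assuming a dict of arrays as the operand type A; the names params, scale_up and reset are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Operand A: a dict pytree holding an array and a scalar.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.float32(0.5)}


def scale_up(p):
    # Returns a dict with the same keys, shapes and dtypes as the input.
    return {"w": p["w"] * 2.0, "b": p["b"] * 2.0}


def reset(p):
    # Same output structure as scale_up, so both branches have type A -> B.
    return {"w": jnp.zeros_like(p["w"]), "b": jnp.zeros_like(p["b"])}


out = jax.lax.cond(True, scale_up, reset, params)
print(out["w"], out["b"])
```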
- Returns
  Value (B) of either true_fun(operand) or false_fun(operand), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
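Since both branches are typed A -> B, they must produce the same output type B: matching pytree structure, shapes and dtypes. A sketch, assuming current JAX behavior, where a shape mismatch between branches is rejected with a TypeError at trace time; branches_agree is an illustrative name, not a JAX function.

```python
import jax
import jax.numpy as jnp


def branches_agree(x):
    # Both branches return a (float32 array, int32 scalar) tuple: the same type B.
    return jax.lax.cond(
        x.sum() > 0,
        lambda v: (v * 2.0, jnp.int32(1)),
        lambda v: (v * 0.0, jnp.int32(0)),
        x)


doubled, tag = branches_agree(jnp.array([1.0, 2.0]))

# Branches returning arrays of different shapes do not share a type B.
try:
    jax.lax.cond(True, lambda _: jnp.zeros(2), lambda _: jnp.zeros(3), None)
    mismatch_rejected = False
except TypeError:
    mismatch_rejected = True
```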