jax.lax.cond
jax.lax.cond(pred, true_operand, true_fun, false_operand, false_fun)

Conditionally apply true_fun to true_operand or false_fun to false_operand. Has equivalent semantics to this Python implementation:
    def cond(pred, true_operand, true_fun, false_operand, false_fun):
        if pred:
            return true_fun(true_operand)
        else:
            return false_fun(false_operand)
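The reference implementation above uses a Python if, which cannot branch on a traced value inside jax.jit. The following is a minimal sketch of the same per-branch semantics under jit. It assumes a JAX release that also provides the cond(pred, true_fun, false_fun, operand) calling form; per_branch_cond is a hypothetical helper (not part of JAX) that packs both operands into one tuple so that each branch picks the operand it needs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def per_branch_cond(pred, true_operand, true_fun, false_operand, false_fun):
    """Hypothetical helper: per-branch-operand cond expressed with the
    cond(pred, true_fun, false_fun, operand) calling form (assumed available)."""
    # Both operands travel together; each branch selects its own element.
    return lax.cond(
        pred,
        lambda ops: true_fun(ops[0]),
        lambda ops: false_fun(ops[1]),
        (true_operand, false_operand),
    )


@jax.jit
def f(x, y):
    # Inside jit, `x > 0` is a traced boolean scalar, so `if x > 0:` would fail.
    return per_branch_cond(x > 0, x, jnp.sin, y, jnp.cos)


print(f(1.0, 2.0))   # takes the true branch: sin(1.0)
print(f(-1.0, 2.0))  # takes the false branch: cos(2.0)
```

Both branches must return values with the same structure, shape, and dtype; here sin and cos of float32 scalars both yield float32 scalars.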
pred must be a scalar; collection types (list, tuple) are not supported.
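A common way to satisfy the scalar requirement is to reduce an elementwise comparison to a single boolean before branching. This sketch assumes a JAX release that provides the cond(pred, true_fun, false_fun, operand) calling form; clip_if_any_negative is a hypothetical example function. In the JAX versions we are aware of, an array pred with more than zero dimensions is also rejected with a TypeError, which the last lines demonstrate.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def clip_if_any_negative(x):
    # `x < 0` is a boolean array; jnp.any reduces it to a scalar usable as pred.
    pred = jnp.any(x < 0)
    # Both branches return a float32 array with the same shape as x.
    return lax.cond(pred, lambda v: jnp.maximum(v, 0.0), lambda v: v, x)


print(clip_if_any_negative(jnp.array([1.0, -2.0, 3.0])))  # negatives clipped to 0
print(clip_if_any_negative(jnp.array([1.0, 2.0])))        # returned unchanged

# A non-scalar pred is rejected before either branch is traced.
try:
    lax.cond(jnp.array([True, False]), lambda v: v, lambda v: -v, 1.0)
except TypeError as err:
    print("rejected:", err)
```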