jax.lax.cond

jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>, linear=None)

Conditionally apply true_fun or false_fun.

Wraps XLA’s Conditional operator.

Provided the arguments are correctly typed, cond() has semantics equivalent to the following Python implementation, where pred must be a scalar:

def cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)
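
As an illustration of the semantics above, here is a minimal sketch of calling cond inside a jit-compiled function. The helper name safe_sqrt and its behavior for negative inputs are assumptions made for this example, not part of the API. Both branches take the same operand and return a value of the same shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax


def safe_sqrt(x):
    # Hypothetical helper: square root for non-negative x, zero otherwise.
    # Both branches receive the same operand and return a float32 scalar.
    return lax.cond(
        x >= 0,                       # pred: a boolean scalar
        lambda v: jnp.sqrt(v),        # true_fun
        lambda v: jnp.zeros_like(v),  # false_fun
        x,                            # operand passed to the chosen branch
    )


safe_sqrt_jit = jax.jit(safe_sqrt)

pos = safe_sqrt_jit(jnp.float32(4.0))   # takes the true branch
neg = safe_sqrt_jit(jnp.float32(-1.0))  # takes the false branch
print(pos, neg)
```

Under jit, an ordinary Python if on x >= 0 would fail because the predicate is a traced value. Expressing the branch with cond keeps the choice inside the compiled computation.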

In contrast with jax.lax.select(), cond indicates that only one of the two branches is executed (up to compiler rewrites and optimizations). However, when cond is transformed with vmap() to operate over a batch of predicates, it is converted to select(), so both branches are evaluated for every batch element.
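
The batched case can be sketched as follows. The function name branchy and the input values are assumptions chosen for illustration. Each batch element gets the result of the branch selected by its own predicate.

```python
import jax
import jax.numpy as jnp
from jax import lax


def branchy(p, x):
    # Hypothetical per-example function: add 1 if p is True, else subtract 1.
    return lax.cond(p, lambda v: v + 1.0, lambda v: v - 1.0, x)


preds = jnp.array([True, False, True])
xs = jnp.array([10.0, 20.0, 30.0])

# vmap maps over the leading axis of both the predicates and the operands.
batched = jax.vmap(branchy)(preds, xs)
print(batched)
```

Because a batched predicate can differ per element, the result has to be assembled elementwise from both branch outputs. Side-effect-free branches therefore give the same values either way, but an expensive branch is not skipped in the batched case.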

Parameters:
  • pred – Boolean scalar type, indicating which branch function to apply.

  • true_fun (Callable) – Function (A -> B), to be applied if pred is True.

  • false_fun (Callable) – Function (A -> B), to be applied if pred is False.

  • operands – Operands (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.

Returns:

Value (B) of either true_fun(*operands) or false_fun(*operands), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
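
To illustrate pytree operands and return values, here is a minimal sketch that uses a dict as the operand. The names step, do_reset, and do_incr and the state layout are assumptions for this example. Both branches return a dict with the same keys, shapes, and dtypes, so their output types match.

```python
import jax.numpy as jnp
from jax import lax


def step(state, reset):
    # state is a pytree: a dict with an int32 counter and a float32 total.
    def do_reset(s):
        # Zero the counter and keep the running total unchanged.
        return {"count": jnp.zeros_like(s["count"]), "total": s["total"]}

    def do_incr(s):
        # Advance the counter by one and the total by a fixed amount.
        return {"count": s["count"] + 1, "total": s["total"] + 2.5}

    return lax.cond(reset, do_reset, do_incr, state)


state = {"count": jnp.int32(3), "total": jnp.float32(10.0)}

incremented = step(state, False)  # false_fun (do_incr) is applied
cleared = step(state, True)       # true_fun (do_reset) is applied
print(incremented, cleared)
```

If the two branches returned different pytree structures or dtypes, cond would raise an error at trace time, so it helps to build both outputs from the same template.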