jax.lax.cond

jax.lax.cond(pred, true_operand, true_fun, false_operand, false_fun)

Conditionally apply true_fun to true_operand or false_fun to false_operand.

Has equivalent semantics to this Python implementation:

```python
def cond(pred, true_operand, true_fun, false_operand, false_fun):
  if pred:
    return true_fun(true_operand)
  else:
    return false_fun(false_operand)
```
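The reference implementation above uses a Python `if`, which cannot branch on a traced value inside `jax.jit`. The sketch below reproduces the same per-branch semantics under `jit`. It assumes the `lax.cond(pred, true_fun, false_fun, operand)` calling convention documented in newer JAX releases, so it packs both operands into one tuple and lets each branch pick its own. The helper name `cond_per_branch` is hypothetical, not part of the JAX API. As with any `lax.cond`, both branches must return values of the same shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cond_per_branch(pred, true_operand, true_fun, false_operand, false_fun):
    """Hypothetical helper mirroring the per-branch signature documented above.

    Both operands are passed to lax.cond as a single tuple; each branch
    function selects only the operand that belongs to it.
    """
    return lax.cond(
        pred,
        lambda ops: true_fun(ops[0]),   # true branch sees only true_operand
        lambda ops: false_fun(ops[1]),  # false branch sees only false_operand
        (true_operand, false_operand),
    )


@jax.jit
def step(x, y):
    # Under jit, x > 0 is a traced scalar, so a Python `if` on it would fail.
    return cond_per_branch(x > 0, x, jnp.sqrt, y, jnp.negative)


print(step(4.0, 3.0))   # pred is True: applies jnp.sqrt to x
print(step(-1.0, 3.0))  # pred is False: applies jnp.negative to y
```

Packing the operands into a tuple keeps the call to a single `operand` argument, which is the simplest form supported by the newer signature.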

pred must be a scalar; collection types such as list and tuple are not supported.
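Because pred must be a scalar, an elementwise comparison on an array has to be reduced to a single boolean (for example with `jnp.any` or `jnp.all`) before it can select a branch. A minimal sketch, again assuming the newer `lax.cond(pred, true_fun, false_fun, operand)` calling convention; the function name `clip_if_any_negative` is illustrative only:

```python
import jax.numpy as jnp
from jax import lax


def clip_if_any_negative(x):
    # x < 0 is a boolean array, not a scalar, so it cannot serve as pred
    # directly; jnp.any reduces it to a single boolean scalar.
    pred = jnp.any(x < 0)
    return lax.cond(
        pred,
        lambda v: jnp.maximum(v, 0.0),  # some entry is negative: clip to zero
        lambda v: v,                    # no negatives: return input unchanged
        x,
    )


print(clip_if_any_negative(jnp.array([1.0, -2.0, 3.0])))
print(clip_if_any_negative(jnp.array([1.0, 2.0, 3.0])))
```

Both branches return a float array with the same shape as x, which keeps their output types identical as lax.cond requires.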