jax.lax.select

jax.lax.select(pred, on_true, on_false)

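Selects between two branches based on a boolean predicate. The selection is elementwise: where pred is True the result takes the corresponding element of on_true, and where it is False it takes the element of on_false.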

Wraps XLA’s Select operator.
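A minimal usage sketch follows. The variable names are illustrative; it shows an elementwise boolean predicate and a scalar predicate that picks one whole array. Note that lax does not broadcast Python scalars here, so the "else" value is built with jnp.zeros_like to match the shape and dtype of x.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-2.0, -0.5, 0.0, 1.5])

# Elementwise: keep x where x > 0, otherwise take 0.0.
# The predicate is a boolean array with the same shape as both branches.
positive_part = lax.select(x > 0, x, jnp.zeros_like(x))
print(positive_part)  # elementwise values: 0.0, 0.0, 0.0, 1.5

# A scalar (shape ()) boolean predicate selects one of the arrays as a whole.
a = jnp.ones(3)
b = jnp.full(3, 2.0)
picked = lax.select(jnp.array(False), a, b)
print(picked)  # every element is 2.0, taken from b
```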

In general, select() leads to evaluation of both branches: the arrays passed as on_true and on_false are both computed, although the compiler may elide computations where possible. For a similar function that usually evaluates only a single branch, see cond().
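The sketch below contrasts the two. It assumes a scalar input so that lax.cond can be used, and the function names are hypothetical. The forward values agree, but because select computes jnp.log(x) even where it is not selected, the discarded branch still shows up in the gradient: at x = 0 the log branch receives a zero cotangent divided by x = 0, which is nan.

```python
import jax
import jax.numpy as jnp
from jax import lax

def safe_log_select(x):
    # Both jnp.log(x) and jnp.zeros_like(x) are computed; for x <= 0 the
    # log yields -inf or nan, which select then discards in the forward pass.
    return lax.select(x > 0, jnp.log(x), jnp.zeros_like(x))

def safe_log_cond(x):
    # lax.cond takes a scalar predicate and two functions, and usually
    # runs only the function that is chosen.
    return lax.cond(x > 0, jnp.log, jnp.zeros_like, x)

x = jnp.float32(0.0)
print(safe_log_select(x), safe_log_cond(x))  # both 0.0

# The unselected log branch still participates in differentiation of select,
# so its gradient at x = 0 is nan; the cond version gives 0.0.
print(jax.grad(safe_log_select)(x), jax.grad(safe_log_cond)(x))
```

A common workaround, when select must be kept, is to feed the log a safe input on the unselected positions as well, so that neither branch produces an infinite value.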

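Parameters:

pred (ArrayLike) – boolean scalar or array indicating, elementwise, which input each output element is taken from. It must either be a scalar or have the same shape as on_true and on_false.
on_true (ArrayLike) – value selected where pred is True.
on_false (ArrayLike) – value selected where pred is False. Must have the same shape and dtype as on_true.

Returns:

Array with the same shape and dtype as on_true and on_false.

Return type:

Array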
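As a sketch of the shape and dtype contract above: lax.select does not promote dtypes, so mixing float32 and int32 branches is rejected (as a TypeError in the JAX versions we are aware of). An explicit cast makes the call valid. The array names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

pred = jnp.array([True, False, True])
f32 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
i32 = jnp.array([10, 20, 30], dtype=jnp.int32)

# Matching shapes and dtypes: the result keeps that shape and dtype.
out = lax.select(pred, f32, -f32)
print(out.shape, out.dtype)  # (3,) float32

# Mismatched dtypes: lax performs no implicit promotion.
rejected = False
try:
    lax.select(pred, f32, i32)
except TypeError as err:
    rejected = True
    print("rejected:", err)

# Casting one branch explicitly satisfies the same-dtype requirement.
fixed = lax.select(pred, f32, i32.astype(f32.dtype))
print(fixed)  # elementwise values: 1.0, 20.0, 3.0
```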