jax.lax.select

jax.lax.select(pred, on_true, on_false)

Selects between two branches based on a boolean predicate.

Wraps XLA’s Select operator.

In general select() leads to evaluation of both branches, although the compiler may elide computations if possible. For a similar function that usually evaluates only a single branch, see cond().
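The difference between the two functions can be sketched as follows. This is a minimal illustration, not taken from the original page: the array `x` and the two lambda branches are made up for the example. `select()` picks entries elementwise from two arrays that have already been computed, while `cond()` takes a scalar predicate and two branch functions.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input (an assumption for this sketch).
x = jnp.array([-2.0, 0.5, 3.0])

# select: both candidate arrays (x and the zeros) are built in full,
# then entries are picked elementwise according to the boolean mask.
clipped = lax.select(x > 0, x, jnp.zeros_like(x))

# cond: the predicate is a scalar, and the branches are functions of the
# operand; usually only the chosen branch is executed.
total = lax.cond(jnp.sum(x) > 0, lambda v: v * 2.0, lambda v: -v, x)
```

Here `clipped` keeps the positive entries of `x` and zeroes the rest, and `total` is the output of whichever branch the scalar predicate chooses.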

Parameters:
  • pred (jax.typing.ArrayLike) – boolean array

  • on_true (jax.typing.ArrayLike) – array containing entries to return where pred is True. Must have the same shape as pred, and the same shape and dtype as on_false.

  • on_false (jax.typing.ArrayLike) – array containing entries to return where pred is False. Must have the same shape as pred, and the same shape and dtype as on_true.

Returns:

array with the same shape and dtype as on_true and on_false.

Return type:

jax.Array
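Because on_true and on_false must agree in shape and dtype, a fill value usually has to be expanded and cast by hand before calling select(). The sketch below shows one way to do that; the names `fill`, `out`, and `same` are hypothetical, and `jnp.where` is included only as a comparison point since it handles broadcasting and type promotion itself.

```python
import jax.numpy as jnp
from jax import lax

pred = jnp.array([True, False, True])
on_true = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
fill = jnp.int32(0)  # scalar with a different dtype than on_true

# Match shape and dtype explicitly before calling lax.select.
on_false = jnp.broadcast_to(fill, on_true.shape).astype(on_true.dtype)
out = lax.select(pred, on_true, on_false)

# jnp.where accepts the mismatched scalar directly and gives the same values.
same = jnp.where(pred, on_true, fill)
```

The explicit form makes the final shape and dtype visible at the call site, which can be useful when strict typing is wanted; `jnp.where` is the more convenient choice when NumPy-style promotion is acceptable.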