jax.lax.select

jax.lax.select(pred, on_true, on_false)

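Selects entries elementwise from on_true or on_false according to a boolean predicate. Wraps XLA’s Select operator.

In general, select() evaluates both operands, although the compiler may elide computations where possible. For a similar function that usually evaluates only a single branch, see jax.lax.cond().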
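A minimal sketch of elementwise selection follows. The operand values are arbitrary illustrations. Where pred is True, the entry comes from on_true; elsewhere, it comes from on_false.

```python
import jax.numpy as jnp
from jax import lax

# Boolean mask with the same shape as both operands.
pred = jnp.array([True, False, True])
on_true = jnp.array([1, 2, 3])
on_false = jnp.array([10, 20, 30])

# Entry i is on_true[i] where pred[i] is True, otherwise on_false[i].
result = lax.select(pred, on_true, on_false)
# result.tolist() == [1, 20, 3]
```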

Parameters
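  • pred (Any) – boolean array, either a scalar or of the same shape as on_true and on_false. Where it is True the result takes the entry from on_true, elsewhere from on_false.

  • on_true (Any) – array whose entries are returned where pred is True. Must have the same shape and dtype as on_false.

  • on_false (Any) – array whose entries are returned where pred is False. Must have the same shape and dtype as on_true.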
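The sketch below illustrates the operand constraints. It assumes a recent JAX release in which shape or dtype mismatches are reported as TypeError when the operation is traced. The helper raises_type_error is hypothetical and exists only for this illustration. Unlike jax.numpy.where, lax.select performs no broadcasting or dtype promotion.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([-1.0, -2.0, -3.0])
mask = jnp.array([True, False, True])

# A scalar boolean predicate selects one operand as a whole.
whole = lax.select(jnp.array(True), x, y)

# Hypothetical helper: reports whether lax.select rejects the operands.
def raises_type_error(pred, on_true, on_false):
    try:
        lax.select(pred, on_true, on_false)
    except TypeError:
        return True
    return False

# float32 vs int32 operands: no implicit dtype promotion.
dtype_mismatch = raises_type_error(mask, x, jnp.array([0, 0, 0], dtype=jnp.int32))
# Shape (3,) vs shape (): no implicit broadcasting.
shape_mismatch = raises_type_error(mask, x, jnp.zeros((), dtype=x.dtype))

# jnp.where, by contrast, broadcasts the scalar 0.0 against x.
broadcast = jnp.where(mask, x, 0.0)
```

When broadcasting or NumPy-style type promotion is wanted, jnp.where is the more convenient entry point. lax.select is the strict primitive beneath it.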

Returns

  An array with the same shape and dtype as on_true and on_false.

Return type

Any