jax.lax.select
- jax.lax.select(pred, on_true, on_false)
Selects between two branches based on a boolean predicate.
Wraps XLA’s Select operator.
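A minimal sketch of elementwise selection, assuming a standard JAX install; the input values are illustrative. All three arguments share shape (4,), and the two branches share one floating-point dtype, as the parameter descriptions require.

```python
import jax.numpy as jnp
from jax import lax

# Boolean predicate with the same shape as both branches.
pred = jnp.array([True, False, True, False])
on_true = jnp.array([1.0, 2.0, 3.0, 4.0])
on_false = jnp.array([-1.0, -2.0, -3.0, -4.0])

# Take the on_true entry where pred is True and the on_false entry elsewhere.
result = lax.select(pred, on_true, on_false)
# result holds [1.0, -2.0, 3.0, -4.0]
```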
In general, select() leads to evaluation of both branches, although the compiler may elide computations if possible. For a similar function that usually evaluates only a single branch, see cond().
- Parameters
  - pred (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – boolean array.
  - on_true (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – array containing entries to return where pred is True. Must have the same shape as pred, and the same shape and dtype as on_false.
  - on_false (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – array containing entries to return where pred is False. Must have the same shape as pred, and the same shape and dtype as on_true.
- Returns
  array with the same shape and dtype as on_true and on_false.
- Return type
  Array
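The sketch below contrasts select() with cond(), assuming a standard JAX install; sqrt_or_zero_select and sqrt_or_zero_cond are hypothetical helpers written for illustration. With select(), both branch arrays are part of the traced computation, so jnp.sqrt is applied even to the negative entry (producing nan) before the predicate discards that value. cond() instead takes a single scalar predicate and usually runs only the chosen function; when transformed with vmap over a batch of predicates, it is converted to select(). The last lines check that select() does not promote mismatched branch dtypes.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([-1.0, 4.0])

@jax.jit
def sqrt_or_zero_select(v):
    # Hypothetical helper. Both branch arrays are computed: jnp.sqrt runs on
    # every element, including -1.0 (giving nan), and select keeps the sqrt
    # value only where the predicate is True.
    return lax.select(v >= 0, jnp.sqrt(v), jnp.zeros_like(v))

@jax.jit
def sqrt_or_zero_cond(v):
    # Hypothetical helper. cond takes one scalar predicate and chooses a whole
    # function, so usually only that branch is executed.
    return lax.cond(jnp.all(v >= 0), jnp.sqrt, jnp.zeros_like, v)

per_element = sqrt_or_zero_select(x)  # holds [0.0, 2.0]
whole_array = sqrt_or_zero_cond(x)    # holds [0.0, 0.0]: x has a negative entry

# Branches with different dtypes are rejected rather than promoted.
rejected = False
try:
    lax.select(x >= 0, x, jnp.zeros(2, dtype=jnp.int32))
except TypeError:
    rejected = True
```

If the branches differ in dtype, converting one of them explicitly (for example with astype) before calling select() satisfies the requirement that on_true and on_false share a dtype.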