jax.Array.choose

abstract Array.choose(choices, out=None, mode='raise')

Construct an array by choosing elements from multiple arrays, using the integer entries of this array as indices into choices.

Refer to jax.numpy.choose() for the full documentation.
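A minimal sketch of the basic behaviour, assuming only `jax.numpy` is available. The names `idx` and `choices` are illustrative. Output position `i` takes its value from `choices[idx[i]][i]`, and the method form on the index array matches the function form `jnp.choose(idx, choices)`.

```python
import jax.numpy as jnp

# Index array: each entry selects which choice array supplies that position.
idx = jnp.array([0, 2, 1, 0])

choices = [
    jnp.array([10, 11, 12, 13]),
    jnp.array([20, 21, 22, 23]),
    jnp.array([30, 31, 32, 33]),
]

# Method form, called on the index array.
result = idx.choose(choices)

# Function form; should produce the same values.
same = jnp.choose(idx, choices)

print(result.tolist())  # [10, 31, 22, 13]
```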

Parameters:
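  • self (Array) – integer array of indices; each entry selects which array in choices supplies the output element at that position.

  • choices (Sequence[ArrayLike]) – the arrays to choose from; self and every element of choices are broadcast to a common shape.

  • out (None) – not supported in JAX; leave it as None.

  • mode (str) – how out-of-range indices are handled: 'raise' (the default), 'wrap', or 'clip'.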
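The sketch below compares the three mode values on an index array that contains out-of-range entries; the marker values 100, 200 and 300 are illustrative. It assumes, as in recent JAX releases, that 'raise' checks the index values eagerly and raises ValueError, and that this check needs concrete values, so inside jax.jit you would pass mode='wrap' or mode='clip' instead.

```python
import jax
import jax.numpy as jnp

# Three choice arrays, each filled with a single marker value.
choices = [jnp.full((5,), 100), jnp.full((5,), 200), jnp.full((5,), 300)]

# Valid indices are 0, 1, 2; -1 and 3 are out of range.
idx = jnp.array([-1, 0, 1, 2, 3])

# 'wrap': indices are taken modulo len(choices), so -1 -> 2 and 3 -> 0.
wrapped = idx.choose(choices, mode="wrap")
print(wrapped.tolist())  # [300, 100, 200, 300, 100]

# 'clip': indices are clamped to the range [0, len(choices) - 1].
clipped = idx.choose(choices, mode="clip")
print(clipped.tolist())  # [100, 100, 200, 300, 300]

# 'raise' (the default) rejects out-of-range entries when called eagerly.
try:
    idx.choose(choices)
    raised = False
except ValueError:
    raised = True
print(raised)  # True

# Under jax.jit, use 'wrap' or 'clip' (assumption: 'raise' needs concrete indices).
clip_jit = jax.jit(lambda i: i.choose(choices, mode="clip"))
print(clip_jit(idx).tolist())  # [100, 100, 200, 300, 300]
```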

Return type:

Array
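The returned Array has the broadcast shape of the index array and the choices. A small sketch, with illustrative values, mixes a scalar choice and a row vector against a 2-D index array. With exactly two choices, the result matches jnp.where(idx == 1, second, first), which is one way to sanity-check the semantics.

```python
import jax.numpy as jnp

# A (2, 3) index array.
idx = jnp.array([[0, 1, 0],
                 [1, 1, 0]])

# Choices of different shapes: a scalar and a length-3 row.
first = jnp.array(-1)
second = jnp.array([10, 20, 30])

out = idx.choose([first, second])
print(out.shape)     # (2, 3)
print(out.tolist())  # [[-1, 20, -1], [10, 20, -1]]

# Equivalent formulation for the two-choice case.
via_where = jnp.where(idx == 1, second, first)
print(via_where.tolist())  # [[-1, 20, -1], [10, 20, -1]]
```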