
jax.numpy.choose(a, choices, out=None, mode='raise')

Construct an array by stacking slices of choice arrays.

JAX implementation of numpy.choose().

The semantics of this function can be confusing, but in the simplest case, where a is a one-dimensional array, choices is a two-dimensional array, and all entries of a are in bounds (i.e. 0 <= a_i < len(choices)), the function is equivalent to the following:

import jax.numpy as jnp

def choose(a, choices):
  return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])

In the more general case, a may have any number of dimensions and choices may be an arbitrary sequence of broadcast-compatible arrays. In this case, again for in-bounds indices, the logic is equivalent to:

import numpy as np
import jax.numpy as jnp

def choose(a, choices):
  a, *choices = jnp.broadcast_arrays(a, *choices)
  choices = jnp.array(choices)  # shape: (len(choices), *a.shape)
  return jnp.array([choices[(a[idx],) + idx] for idx in np.ndindex(a.shape)]).reshape(a.shape)
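
The element-by-element loop above is easy to read but slow. The same in-bounds semantics can be written as a single gather along a new leading axis. The sketch below uses NumPy's take_along_axis; jax.numpy also provides broadcast_arrays, stack, and take_along_axis, so it should port directly. The name choose_vectorized is hypothetical, and the sketch assumes every index is already in bounds.

```python
import numpy as np


def choose_vectorized(a, choices):
    """Hypothetical in-bounds reference for choose: one gather instead of a loop."""
    # Broadcast the index array and every choice to a common shape.
    a, *choices = np.broadcast_arrays(a, *choices)
    stacked = np.stack(choices)          # shape: (len(choices), *a.shape)
    idx = np.expand_dims(a, axis=0)      # shape: (1, *a.shape)
    # result[0, *i] = stacked[a[i], *i] for every multi-index i
    return np.take_along_axis(stacked, idx, axis=0)[0]


choices = np.array([[1, 2, 3, 4],
                    [5, 6, 7, 8],
                    [9, 10, 11, 12]])
print(choose_vectorized(np.array([2, 0, 1, 0]), choices))  # [9 2 7 4]
```

Stacking the broadcast choices turns "pick from choice a[i] at position i" into an ordinary indexed read along axis 0. This is why the result shape is the broadcast shape of a and all choices.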

The only additional complexity comes from the mode argument, which controls the behavior for out-of-bounds indices in a, as described below.
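
One way to picture mode is as a preprocessing step that maps every entry of a into the range [0, len(choices)) before the in-bounds logic above runs. The sketch below models this with NumPy, whose numpy.choose this function mirrors. The helper normalize_indices is hypothetical. The ValueError it raises for 'raise' is an assumption and is not necessarily the exception JAX uses.

```python
import numpy as np


def normalize_indices(a, n, mode):
    """Hypothetical helper: map indices in `a` into [0, n) according to `mode`."""
    a = np.asarray(a)
    if mode == 'wrap':
        # Modulo arithmetic; NumPy's % keeps results non-negative, so -1 -> n - 1.
        return a % n
    if mode == 'clip':
        # Values below 0 become 0; values above n - 1 become n - 1.
        return np.clip(a, 0, n - 1)
    if mode == 'raise':
        # Assumed behavior: reject any out-of-bounds entry (exception type assumed).
        if np.any((a < 0) | (a >= n)):
            raise ValueError("index out of bounds for choose")
        return a
    raise ValueError(f"unknown mode: {mode!r}")


choices = np.array([[1, 2, 3, 4],
                    [5, 6, 7, 8],
                    [9, 10, 11, 12]])
a2 = np.array([2, 0, 1, 4])  # last index is out of bounds

# Normalizing first, then choosing in-bounds, matches choose with the given mode.
for mode in ('wrap', 'clip'):
    manual = np.choose(normalize_indices(a2, len(choices), mode), choices)
    print(mode, manual, np.choose(a2, choices, mode=mode))
# wrap: 9, 2, 7, 8   clip: 9, 2, 7, 12
```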

Parameters:

  • a (ArrayLike) – an N-dimensional array of integer indices.

  • choices (Array | np.ndarray | Sequence[ArrayLike]) – an array or sequence of arrays. All arrays in the sequence must be mutually broadcast compatible with a.

  • out (None) – unused by JAX

  • mode (str) – specifies the out-of-bounds indexing mode; one of 'raise' (default), 'wrap', or 'clip'. Note that the default mode of 'raise' is not compatible with JAX transformations such as jax.jit.


Returns:

an array containing stacked slices from choices at the indices specified by a. The shape of the result is broadcast_shapes(a.shape, *(c.shape for c in choices)).

Return type:

Array



Here is the simplest case, a 1D index array with a 2D choice array, in which choose selects the indexed value from each column:

>>> choices = jnp.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [ 9, 10, 11, 12]])
>>> a = jnp.array([2, 0, 1, 0])
>>> jnp.choose(a, choices)
Array([9, 2, 7, 4], dtype=int32)

The mode argument specifies what to do with out-of-bounds indices; the options are to either wrap or clip them:

>>> a2 = jnp.array([2, 0, 1, 4])  # last index is out of bounds
>>> jnp.choose(a2, choices, mode='clip')
Array([ 9,  2,  7, 12], dtype=int32)
>>> jnp.choose(a2, choices, mode='wrap')
Array([9, 2, 7, 8], dtype=int32)
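
Because the default mode='raise' is not compatible with JAX transformations, code that calls choose under jax.jit should pass mode='clip' or mode='wrap' explicitly. The following minimal sketch reuses the arrays from the examples above. The function name choose_clipped is hypothetical. The explanation in the code comment, that bounds checking would need concrete index values, is an assumption about the cause and is not stated on this page.

```python
import jax
import jax.numpy as jnp


@jax.jit
def choose_clipped(a, choices):
    # Inside jit, `a` is an abstract tracer. Assumption: 'raise' would need the
    # concrete index values to check bounds, while 'clip' (or 'wrap') does not.
    return jnp.choose(a, choices, mode='clip')


choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a2 = jnp.array([2, 0, 1, 4])  # last index is out of bounds
print(choose_clipped(a2, choices))  # values: 9, 2, 7, 12
```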

In the more general case, choices may be a sequence of array-like objects with any broadcast-compatible shapes:

>>> choice_1 = jnp.array([1, 2, 3, 4])
>>> choice_2 = 99
>>> choice_3 = jnp.array([[10],
...                       [20],
...                       [30]])
>>> a = jnp.array([[0, 1, 2, 0],
...                [1, 2, 0, 1],
...                [2, 0, 1, 2]])
>>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
Array([[ 1, 99, 10,  4],
       [99, 20,  3, 99],
       [30,  2, 99, 30]], dtype=int32)
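
The shape of the result above follows the rule given under Returns: it is the broadcast of a.shape with the shape of every choice. The small sketch below uses NumPy's np.broadcast_shapes to predict that shape without calling choose. jax.numpy also exposes a broadcast_shapes function. The variable names are illustrative.

```python
import numpy as np

# Shapes from the example above: a is (3, 4); choice_1 is (4,),
# choice_2 is the scalar 99 (shape ()), and choice_3 is (3, 1).
a_shape = (3, 4)
choice_shapes = [(4,), (), (3, 1)]

result_shape = np.broadcast_shapes(a_shape, *choice_shapes)
print(result_shape)  # (3, 4)

# A choice whose shape cannot broadcast against a is rejected up front.
try:
    np.broadcast_shapes(a_shape, (3,))
except ValueError as err:
    print("not broadcast-compatible:", err)
```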