jax.numpy.choose

jax.numpy.choose(a, choices, out=None, mode='raise')

Construct an array by stacking slices of choice arrays.

JAX implementation of numpy.choose().

The semantics of this function can be confusing. In the simplest case, where a is a one-dimensional array, choices is a two-dimensional array, and all entries of a are in-bounds (i.e. 0 <= a_i < len(choices)), the function is equivalent to the following:

import jax.numpy as jnp

def choose(a, choices):
  return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])

In the more general case, a may have any number of dimensions and choices may be an arbitrary sequence of broadcast-compatible arrays. In this case, again for in-bound indices, the logic is equivalent to:

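import numpy as np
import jax.numpy as jnp

def choose(a, choices):
  a, *choices = jnp.broadcast_arrays(a, *choices)
  choices = jnp.array(choices)  # stacked shape: (len(choices), *a.shape)
  flat = [choices[(a[idx], *idx)] for idx in np.ndindex(a.shape)]
  return jnp.array(flat).reshape(a.shape)  # restore the broadcast shape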
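
As a quick sanity check, the general logic above can be compared against jnp.choose directly. The sketch below is illustrative only: choose_reference is a hypothetical helper name (not part of JAX), it assumes every index is in-bound, and it reshapes the collected values back to the broadcast shape of a.

```python
import numpy as np
import jax.numpy as jnp

def choose_reference(a, choices):
  # Hypothetical helper (not part of JAX); assumes all indices are in-bound.
  a, *choices = jnp.broadcast_arrays(a, *choices)
  stacked = jnp.array(choices)  # shape: (len(choices), *a.shape)
  flat = [stacked[(a[idx], *idx)] for idx in np.ndindex(a.shape)]
  return jnp.array(flat).reshape(a.shape)

a = jnp.array([[0, 1],
               [1, 0]])
# Two broadcast-compatible choices with shapes (2,) and (2, 1).
choices = [jnp.array([10, 20]), jnp.array([[30], [40]])]

expected = choose_reference(a, choices)
actual = jnp.choose(a, choices)
matches = bool(jnp.array_equal(expected, actual))
```

The Python loop is only meant to make the indexing explicit; it is far slower than jnp.choose and would not be used in practice.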

The only additional complexity comes from the mode argument, which controls the behavior for out-of-bound indices in a as described below.

Parameters:
  • a (ArrayLike) – an N-dimensional array of integer indices.

  • choices (Array | np.ndarray | Sequence[ArrayLike]) – an array or sequence of arrays. All arrays in the sequence must be mutually broadcast compatible with a.

  • out (None) – unused by JAX.

  • mode (str) – specify the out-of-bounds indexing mode; one of 'raise' (default), 'wrap', or 'clip'. Note that the default mode of 'raise' is not compatible with JAX transformations.
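
The note on mode can be illustrated with a minimal sketch, assuming a recent JAX version. Under jax.jit the indices are traced rather than concrete, so the bounds check that 'raise' requires cannot be performed, while 'clip' and 'wrap' trace normally. The variable names below are illustrative, and the exact exception type is deliberately not assumed.

```python
import jax
import jax.numpy as jnp

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a = jnp.array([2, 0, 1, 4])  # last index is out-of-bound

# 'clip' (like 'wrap') does not need concrete index values, so it can be jitted.
choose_clip = jax.jit(lambda idx, ch: jnp.choose(idx, ch, mode='clip'))
clipped = choose_clip(a, choices)

# The default mode='raise' needs concrete values for its bounds check, so an
# error is expected when the call is traced under jit.
try:
  jax.jit(jnp.choose)(a, choices)
  raised = False
except Exception:  # exact exception type is not assumed here
  raised = True
```

When a function containing jnp.choose has to be transformed, passing mode='clip' or mode='wrap' explicitly is the simplest way to avoid this failure.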

Returns:

an array containing stacked slices from choices at the indices specified by a. The shape of the result is broadcast_shapes(a.shape, *(c.shape for c in choices)).
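
The shape rule can be checked with a small sketch. The arrays below are made up for illustration; jnp.broadcast_shapes computes the shape that the description above predicts.

```python
import jax.numpy as jnp

a = jnp.array([[0], [1], [0]])                   # shape (3, 1)
choices = [jnp.arange(4), jnp.full((3, 4), 7)]   # shapes (4,) and (3, 4)

result = jnp.choose(a, choices)
expected_shape = jnp.broadcast_shapes(a.shape, *(c.shape for c in choices))
shapes_agree = result.shape == expected_shape
```

Here a selects the first choice for rows 0 and 2 and the second choice for row 1, and every operand is broadcast to (3, 4) before selection.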

Return type:

Array

Examples

Here is the simplest case of a 1D index array with a 2D choice array, in which case this chooses the indexed value from each column:

>>> choices = jnp.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [ 9, 10, 11, 12]])
>>> a = jnp.array([2, 0, 1, 0])
>>> jnp.choose(a, choices)
Array([9, 2, 7, 4], dtype=int32)

The mode argument specifies how out-of-bound indices are handled. Besides the default 'raise', the options are 'clip' and 'wrap':

>>> a2 = jnp.array([2, 0, 1, 4])  # last index out-of-bound
>>> jnp.choose(a2, choices, mode='clip')
Array([ 9,  2,  7, 12], dtype=int32)
>>> jnp.choose(a2, choices, mode='wrap')
Array([9, 2, 7, 8], dtype=int32)
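
One way to read these two results is as a remapping of the indices before selection. The sketch below assumes the semantics documented for numpy.choose, where 'clip' clamps indices to [0, n - 1] and 'wrap' takes them modulo n, with n = len(choices); the variable names are illustrative.

```python
import jax.numpy as jnp

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a2 = jnp.array([2, 0, 1, 4])  # last index out-of-bound
n = len(choices)              # number of choices (3)

clipped_idx = jnp.clip(a2, 0, n - 1)  # 4 is clamped to 2
wrapped_idx = a2 % n                  # 4 wraps around to 1

# Remapping by hand and then choosing matches the corresponding mode.
same_clip = bool(jnp.array_equal(jnp.choose(clipped_idx, choices),
                                 jnp.choose(a2, choices, mode='clip')))
same_wrap = bool(jnp.array_equal(jnp.choose(wrapped_idx, choices),
                                 jnp.choose(a2, choices, mode='wrap')))
```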

In the more general case, choices may be a sequence of array-like objects with any broadcast-compatible shapes.

>>> choice_1 = jnp.array([1, 2, 3, 4])
>>> choice_2 = 99
>>> choice_3 = jnp.array([[10],
...                       [20],
...                       [30]])
>>> a = jnp.array([[0, 1, 2, 0],
...                [1, 2, 0, 1],
...                [2, 0, 1, 2]])
>>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
Array([[ 1, 99, 10,  4],
       [99, 20,  3, 99],
       [30,  2, 99, 30]], dtype=int32)