jax.numpy.choose
jax.numpy.choose(a, choices, out=None, mode='raise')
Construct an array by stacking slices of choice arrays.
JAX implementation of numpy.choose().

The semantics of this function can be confusing, but in the simplest case where a is a one-dimensional array, choices is a two-dimensional array, and all entries of a are in-bounds (i.e. 0 <= a_i < len(choices)), then the function is equivalent to the following:

    def choose(a, choices):
        return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])
In the more general case, a may have any number of dimensions and choices may be an arbitrary sequence of broadcast-compatible arrays. In this case, again for in-bound indices, the logic is equivalent to:

    def choose(a, choices):
        # Broadcast the indices and all choices to one common shape.
        a, *choices = jnp.broadcast_arrays(a, *choices)
        choices = jnp.array(choices)
        # Select element idx of the a[idx]-th choice, then restore the shape.
        out = [choices[(a[idx], *idx)] for idx in np.ndindex(a.shape)]
        return jnp.array(out).reshape(a.shape)
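The general-case logic above can be checked against jnp.choose directly. The following is a minimal sketch, not the library's implementation: choose_reference is a hypothetical helper name, it handles in-bound indices only, and it reshapes the flat list of picked values back to the broadcast shape so the two results are comparable.

```python
import numpy as np
import jax.numpy as jnp


def choose_reference(a, choices):
    """Hypothetical pure-Python reference; valid for in-bound indices only."""
    # Broadcast the index array and every choice to one common shape.
    a, *choices = jnp.broadcast_arrays(a, *choices)
    # Stack the choices along a new leading axis: (n_choices, *a.shape).
    choices = jnp.stack(choices)
    # At each position idx, take element idx of the a[idx]-th choice.
    picked = [choices[(int(a[idx]), *idx)] for idx in np.ndindex(a.shape)]
    return jnp.array(picked).reshape(a.shape)


a = jnp.array([[0, 1],
               [1, 0]])
choices = [jnp.array([10, 20]),     # shape (2,), broadcast along rows
           jnp.array([[1], [2]])]   # shape (2, 1), broadcast along columns

result = choose_reference(a, choices)
expected = jnp.choose(a, choices)   # eager call, so the default mode='raise' is fine
print(result)
```

The Python loop is only meant to make the indexing rule explicit; it is far slower than jnp.choose and is not suitable for use under JAX transformations.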
The only additional complexity comes from the mode argument, which controls the behavior for out-of-bound indices in a as described below.

- Parameters:
a (ArrayLike) – an N-dimensional array of integer indices.
choices (Array | np.ndarray | Sequence[ArrayLike]) – an array or sequence of arrays. All arrays in the sequence must be mutually broadcast compatible with a.
out (None) – unused by JAX.
mode (str) – specify the out-of-bounds indexing mode; one of 'raise' (default), 'wrap', or 'clip'. Note that the default mode of 'raise' is not compatible with JAX transformations.
- Returns:
an array containing stacked slices from choices at the indices specified by a. The shape of the result is broadcast_shapes(a.shape, *(c.shape for c in choices)).
- Return type:
Array
See also
jax.lax.switch(): choose between N functions based on an index.
Examples
Here is the simplest case of a 1D index array with a 2D choice array, in which case this chooses the indexed value from each column:
>>> choices = jnp.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [ 9, 10, 11, 12]])
>>> a = jnp.array([2, 0, 1, 0])
>>> jnp.choose(a, choices)
Array([9, 2, 7, 4], dtype=int32)
The mode argument specifies what to do with out-of-bound indices; options are to either wrap or clip:

>>> a2 = jnp.array([2, 0, 1, 4])  # last index out-of-bound
>>> jnp.choose(a2, choices, mode='clip')
Array([ 9,  2,  7, 12], dtype=int32)
>>> jnp.choose(a2, choices, mode='wrap')
Array([9, 2, 7, 8], dtype=int32)
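Because the default mode of 'raise' is not compatible with JAX transformations, code that runs under jax.jit typically needs to pass 'wrap' or 'clip' explicitly. The sketch below shows one way this might look; choose_jit is a hypothetical wrapper name, and marking mode as a static argument is a design choice made here for illustration rather than a requirement stated by the documentation.

```python
from functools import partial

import jax
import jax.numpy as jnp

choices = jnp.array([[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]])
a2 = jnp.array([2, 0, 1, 4])  # last index is out of bounds


# mode is a Python string, so it is declared static for jit.
@partial(jax.jit, static_argnames="mode")
def choose_jit(a, choices, mode):
    # 'clip' and 'wrap' need no value-dependent error check, so they trace cleanly.
    return jnp.choose(a, choices, mode=mode)


clipped = choose_jit(a2, choices, mode="clip")  # index 4 is clipped to 2
wrapped = choose_jit(a2, choices, mode="wrap")  # index 4 wraps to 4 % 3 == 1
print(clipped, wrapped)
```

With 'clip', the out-of-bound index selects from the last row of choices; with 'wrap', it selects from the row at the index taken modulo the number of choices.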
In the more general case, choices may be a sequence of array-like objects with any broadcast-compatible shapes.

>>> choice_1 = jnp.array([1, 2, 3, 4])
>>> choice_2 = 99
>>> choice_3 = jnp.array([[10],
...                       [20],
...                       [30]])
>>> a = jnp.array([[0, 1, 2, 0],
...                [1, 2, 0, 1],
...                [2, 0, 1, 2]])
>>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
Array([[ 1, 99, 10,  4],
       [99, 20,  3, 99],
       [30,  2, 99, 30]], dtype=int32)
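The result-shape rule given under Returns can be checked for this example. This is a small sketch under one assumption: since choice_2 is a Python scalar with no .shape attribute, jnp.shape(c) is used here in place of c.shape when computing the expected shape.

```python
import jax.numpy as jnp

choice_1 = jnp.array([1, 2, 3, 4])          # shape (4,)
choice_2 = 99                               # Python scalar, shape ()
choice_3 = jnp.array([[10], [20], [30]])    # shape (3, 1)
a = jnp.array([[0, 1, 2, 0],
               [1, 2, 0, 1],
               [2, 0, 1, 2]])               # shape (3, 4)
choices = [choice_1, choice_2, choice_3]

# jnp.shape accepts scalars as well as arrays, unlike the .shape attribute.
expected_shape = jnp.broadcast_shapes(a.shape, *(jnp.shape(c) for c in choices))
result = jnp.choose(a, choices, mode='wrap')
print(expected_shape, result.shape)
```

Both shapes come out as (3, 4): the scalar and the column vector are broadcast up to the shape of a before any element is selected.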