jax.numpy.choose
- jax.numpy.choose(a, choices, out=None, mode='raise')
Construct an array from an index array and a list of arrays to choose from.
LAX-backend implementation of numpy.choose(). Original docstring below.
First of all, if you are confused or uncertain, look at the Examples in the numpy.choose documentation: in its full generality, this function is less simple than it might seem from the following code description (below, ndi = numpy.lib.index_tricks):
np.choose(a,c) == np.array([c[a[I]][I] for I in ndi.ndindex(a.shape)])
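The one-line description can be checked directly against jnp.choose in the simple case where a and every choice array already share one shape, so no broadcasting is involved. This is a small sketch with made-up inputs. It uses np.ndindex, the same ndindex that lives in numpy.lib.index_tricks, exposed at the top level of NumPy.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative inputs: index array a and three choice arrays, all of shape (2, 3).
a = np.array([[0, 1, 2],
              [2, 1, 0]])
c = [np.array([[10, 11, 12], [13, 14, 15]]),
     np.array([[20, 21, 22], [23, 24, 25]]),
     np.array([[30, 31, 32], [33, 34, 35]])]

# Literal translation of the description: for each multi-index I of a, take
# element I of the choice array selected by a[I]. The comprehension produces
# a flat array, so it is reshaped back to a.shape before comparing.
loop_result = np.array([c[a[I]][I] for I in np.ndindex(a.shape)]).reshape(a.shape)

jax_result = jnp.choose(a, c)
assert np.array_equal(np.asarray(jax_result), loop_result)
print(loop_result)
# [[10 21 32]
#  [33 24 15]]
```

Note that the comprehension on its own yields a flat array; the reshape is what makes it match the shape of a.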
But this description omits some subtleties. Here is a fully general summary:
Given an “index” array (a) of integers and a sequence of n arrays (choices), a and each choice array are first broadcast, as necessary, to arrays of a common shape; calling these Ba and Bchoices[i], i = 0,…,n-1, we have that, necessarily, Ba.shape == Bchoices[i].shape for each i. Then, a new array with shape Ba.shape is created as follows:
- if mode='raise' (the default), each element of a (and thus Ba) must be in the range [0, n-1]; if i (in that range) is the value at position (j0, j1, ..., jm) in Ba, then the value at the same position in the new array is the value in Bchoices[i] at that same position;
- if mode='wrap', values in a (and thus Ba) may be any (signed) integer; modular arithmetic is used to map integers outside the range [0, n-1] back into that range, and then the new array is constructed as above;
- if mode='clip', values in a (and thus Ba) may be any (signed) integer; negative integers are mapped to 0, values greater than n-1 are mapped to n-1, and then the new array is constructed as above.
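The summary above amounts to three steps: broadcast, map indices by mode, then select per position. A minimal sketch of those steps follows. The helper name choose_sketch is hypothetical and is not JAX's own implementation. It only handles 'wrap' and 'clip', and it is compared against jnp.choose on made-up inputs whose shapes exercise broadcasting.

```python
import jax.numpy as jnp

def choose_sketch(a, choices, mode="clip"):
    """Hypothetical helper following the summary above; not JAX's own code."""
    n = len(choices)
    # Broadcast the index array and every choice array to one common shape,
    # giving Ba and Bchoices[0..n-1].
    ba, *bchoices = jnp.broadcast_arrays(jnp.asarray(a), *choices)
    # Map out-of-range indices back into [0, n-1] according to mode.
    if mode == "wrap":
        ba = ba % n                  # floor modulo: -1 % 3 == 2
    elif mode == "clip":
        ba = jnp.clip(ba, 0, n - 1)  # negatives -> 0, values > n-1 -> n-1
    else:
        raise ValueError("this sketch only handles 'wrap' and 'clip'")
    # Stack the choices along a new leading axis, shape (n, *Ba.shape), then at
    # every position take the entry from the choice array that Ba selects.
    stacked = jnp.stack(bchoices)
    return jnp.take_along_axis(stacked, ba[None], axis=0)[0]

# Index array of shape (4, 1); choices of shapes (3,), () and (1, 3).
# Everything broadcasts to the common shape (4, 3).
a = jnp.array([[0], [1], [2], [-1]])
choices = [jnp.array([1, 2, 3]), jnp.array(10), jnp.array([[100, 200, 300]])]

for mode in ("wrap", "clip"):
    assert jnp.array_equal(choose_sketch(a, choices, mode), jnp.choose(a, choices, mode=mode))
```

The last row shows where the two modes differ. Its index -1 wraps to 2 (the row [100, 200, 300]) but clips to 0 (the row [1, 2, 3]). Stacking and then gathering along axis 0 is just one way to express "take the value in Bchoices[i] at the same position".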
- Parameters:
a (int array) – This array must contain integers in [0, n-1], where n is the number of choices, unless mode='wrap' or mode='clip', in which case any integers are permissible.
choices (sequence of arrays) – Choice arrays. a and all of the choices must be broadcastable to the same shape. If choices is itself an array (not recommended), then its outermost dimension (i.e., the one corresponding to choices.shape[0]) is taken as defining the “sequence”.
mode ({'raise' (default), 'wrap', 'clip'}, optional) – Specifies how indices outside [0, n-1] will be treated:
  'raise' : an exception is raised
  'wrap' : value becomes value mod n
  'clip' : values < 0 are mapped to 0, values > n-1 are mapped to n-1
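out (None) – Must be None; JAX arrays are immutable, so writing the result into a preallocated output array is not supported.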
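A short sketch of the choices and mode parameters, using made-up values. The same three choice arrays are passed once as a list and once as a 2-D array whose outermost axis plays the role of the sequence. The indices 3 and -2 fall outside [0, 2] and are handled by 'wrap' or 'clip'. Under 'raise', the function has to look at the actual index values to decide whether to raise. Those values are presumably not available while jax.jit traces a function, so this sketch only uses 'wrap' and 'clip' inside the jitted helper. The name choose_jit is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Three choice arrays (n = 3), given once as a list and once as a 2-D array
# whose outermost axis (shape[0] == 3) plays the role of the sequence.
choices_list = [jnp.array([0, 1, 2, 3]),
                jnp.array([10, 11, 12, 13]),
                jnp.array([20, 21, 22, 23])]
choices_array = jnp.stack(choices_list)  # shape (3, 4)

# 3 and -2 fall outside [0, n-1] = [0, 2].
idx = jnp.array([2, 3, 1, -2])

# Hypothetical jitted wrapper; mode is a Python string, so it is marked static.
@partial(jax.jit, static_argnames="mode")
def choose_jit(idx, choices, mode):
    return jnp.choose(idx, choices, mode=mode)

wrapped = choose_jit(idx, choices_array, mode="wrap")  # indices become [2, 0, 1, 1]
clipped = choose_jit(idx, choices_array, mode="clip")  # indices become [2, 2, 1, 0]

# The list form and the 2-D array form give the same result.
assert jnp.array_equal(clipped, jnp.choose(idx, choices_list, mode="clip"))
```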
- Returns:
merged_array – The merged result.
- Return type:
array