jax.numpy.take
- jax.numpy.take(a, indices, axis=None, out=None, mode=None)
Take elements from an array along an axis.
LAX-backend implementation of numpy.take().
In the JAX version, the mode argument defaults to a special mode ("fill") that returns invalid values (e.g., NaN) for out-of-bounds indices. See jax.numpy.ndarray.at for more discussion of out-of-bounds indexing in JAX.
Original docstring below.
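A minimal sketch of the default "fill" behaviour on a floating-point array, using arbitrary example values. For floats the filled value is NaN. The fill value used for other dtypes is not covered here.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 3])  # index 3 is out of bounds for a length-3 array

# No mode is given, so JAX uses its "fill" behaviour: the out-of-bounds
# position is filled with NaN instead of raising an error.
filled = jnp.take(x, idx)  # values: 10.0, 30.0, nan
```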
When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as np.take(arr, indices, axis=3) is equivalent to arr[:,:,:,indices,...].
Explained without fancy indexing, this is equivalent to the following use of ndindex, which sets each of ii, jj, and kk to a tuple of indices:
    Ni, Nk = a.shape[:axis], a.shape[axis+1:]
    Nj = indices.shape
    for ii in ndindex(Ni):
        for jj in ndindex(Nj):
            for kk in ndindex(Nk):
                out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
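The two equivalences above can be checked on a small example. This sketch assumes arbitrary shapes and values (a 5-D jnp.arange array and a 2-D block of in-range indices) and compares jnp.take with the fancy-indexing form and with a transcription of the ndindex loop. The loop is written with NumPy (np.ndindex, np.empty) because JAX arrays are immutable and do not support in-place assignment such as out[...] = ....

```python
import numpy as np
import jax.numpy as jnp

# A 5-D array and a 2-D block of indices into axis 3 (which has length 4).
arr = jnp.arange(2 * 3 * 2 * 4 * 2).reshape(2, 3, 2, 4, 2)
indices = jnp.array([[0, 3], [1, 1], [2, 0]])
axis = 3

taken = jnp.take(arr, indices, axis=axis)

# Fancy-indexing form: arr[:,:,:,indices,...]
fancy = arr[:, :, :, indices, ...]

# Explicit loop from the docstring, using NumPy so `out` can be filled in place.
a = np.asarray(arr)
idx = np.asarray(indices)
Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
Nj = idx.shape
out = np.empty(Ni + Nj + Nk, dtype=a.dtype)
for ii in np.ndindex(Ni):
    for jj in np.ndindex(Nj):
        for kk in np.ndindex(Nk):
            out[ii + jj + kk] = a[ii + (idx[jj],) + kk]

print(taken.shape)  # (2, 3, 2, 3, 2, 2): Ni + Nj + Nk
```

All three results should agree element for element, since the indices are within bounds.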
- Parameters
a (array_like (Ni..., M, Nk...)) – The source array.
indices (array_like (Nj...)) – The indices of the values to extract.
axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.
mode ({'raise', 'wrap', 'clip'}, optional) –
Specifies how out-of-bounds indices will behave.
'raise' – raise an error (default in NumPy; in JAX the default is the "fill" mode described above)
'wrap' – wrap around
'clip' – clip to the range
'clip' mode means that all indices that are too large are replaced by the index that addresses the last element along that axis. Note that this disables indexing with negative numbers.
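The following sketch illustrates the axis and mode parameters on a small 2×3 array with arbitrary values. With axis=None the indices address the flattened array, and an out-of-bounds index is either wrapped (taken modulo the axis length) or clipped to the last valid position. Only 'wrap' and 'clip' are exercised here.

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# axis=None (the default): indices refer to the flattened array.
flat = jnp.take(a, jnp.array([0, 4]))   # picks 1.0 and 5.0
same = a.ravel()[jnp.array([0, 4])]     # explicit flatten-then-index

# Index 4 is out of bounds along axis 1 (length 3).
idx = jnp.array([0, 4])
wrapped = jnp.take(a, idx, axis=1, mode="wrap")  # 4 -> 4 % 3 == 1
clipped = jnp.take(a, idx, axis=1, mode="clip")  # 4 -> 2, the last column
```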
- Returns
out – The returned array has the same type as a.
- Return type
ndarray (Ni..., Nj..., Nk...)
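A quick check of the documented return shape, assuming zero-filled inputs whose values are irrelevant. Taking along axis=1 of a (2, 3, 4) array with a (5, 6) index array replaces the length-3 axis with the index shape, and the result keeps the dtype of a.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3, 4))                      # Ni... = (2,), M = 3, Nk... = (4,)
indices = jnp.zeros((5, 6), dtype=jnp.int32)  # Nj... = (5, 6)

out = jnp.take(a, indices, axis=1)
print(out.shape)             # (2, 5, 6, 4), i.e. Ni... + Nj... + Nk...
print(out.dtype == a.dtype)  # True
```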