- jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)
Take elements from an array along an axis.
LAX-backend implementation of numpy.take().
The JAX version adds several extra parameters, described below, which are forwarded to jax.lax.gather() for finer control over indexing.
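A minimal sketch of passing the two gather hints through jnp.take, assuming a recent JAX release. The array values and the jitted helper `gather_sorted` are hypothetical. The hints are promises rather than checks, so they should only be set when the indices really are unique and sorted; here they are, so the hinted call should match the plain one.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10.0) * 10.0   # values 0., 10., ..., 90.
idx = jnp.array([1, 4, 7])    # ascending and free of duplicates

# Plain call: no assumptions about the indices.
baseline = jnp.take(x, idx)

# Same call with the hints that are forwarded to jax.lax.gather().
hinted = jnp.take(x, idx, unique_indices=True, indices_are_sorted=True)

# Hypothetical helper: the hints can also be used inside a jitted function.
@jax.jit
def gather_sorted(values, indices):
    return jnp.take(values, indices, unique_indices=True, indices_are_sorted=True)

print(baseline)
print(hinted)
print(gather_sorted(x, idx))
```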
Original docstring below.
When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as np.take(arr, indices, axis=3) is equivalent to arr[:,:,:,indices,...].
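A quick check of this equivalence with jax.numpy, using a hypothetical 5-D array so that axis=3 is a non-trailing axis. The shapes and index values are illustrative only.

```python
import jax.numpy as jnp

# Hypothetical 5-D input; axis 3 has length 5.
arr = jnp.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
indices = jnp.array([4, 0, 2])

taken = jnp.take(arr, indices, axis=3)
fancy = arr[:, :, :, indices, ...]   # fancy indexing along the same axis

print(taken.shape)   # (2, 3, 4, 3, 6)
```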
Explained without fancy indexing, this is equivalent to the following use of ndindex, which sets each of ii, jj, and kk to a tuple of indices:
```python
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
    for jj in ndindex(Nj):
        for kk in ndindex(Nk):
            out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
```
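A runnable version of the loop description, sketched with NumPy's np.ndindex and compared against jnp.take. The function name `take_by_loops` and the test arrays are hypothetical; the loop body follows the description line for line, with `out` allocated up front as an array of shape Ni + Nj + Nk.

```python
import numpy as np
import jax.numpy as jnp

def take_by_loops(a, indices, axis):
    """Hypothetical reference take() built from the ndindex loops (NumPy arrays)."""
    a = np.asarray(a)
    indices = np.asarray(indices)
    Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
    Nj = indices.shape
    out = np.empty(Ni + Nj + Nk, dtype=a.dtype)
    for ii in np.ndindex(Ni):
        for jj in np.ndindex(Nj):
            for kk in np.ndindex(Nk):
                # Scalar copy: position jj of the indices selects along `axis`.
                out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
    return out

a = np.arange(24).reshape(2, 3, 4)
indices = np.array([[2, 0], [1, 1]])   # 2-D indices, so Nj == (2, 2)

expected = take_by_loops(a, indices, axis=1)
result = jnp.take(a, indices, axis=1)
print(expected.shape)   # (2, 2, 2, 4)
```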
- Parameters
a (array_like (Ni..., M, Nk...)) – The source array.
indices (array_like (Nj...)) – The indices of the values to extract.
axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.
mode (string, default="fill") – Out-of-bounds indexing mode. The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices. See jax.numpy.ndarray.at for more discussion of out-of-bounds indexing in JAX.
unique_indices (bool, default=False) – If True, the implementation will assume that the indices are unique, which can result in more efficient execution on some backends.
indices_are_sorted (bool, default=False) – If True, the implementation will assume that the indices are sorted in ascending order, which can lead to more efficient execution on some backends.
fill_value (optional) – The fill value to return for out-of-bounds slices when mode is ‘fill’. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.
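A small sketch of the out-of-bounds behavior described for mode and fill_value, assuming a recent JAX release. It also uses mode="clip", which is not described above; it is included on the assumption that jnp.take accepts it (as numpy.take does) and clamps indices to the valid range. Array contents are hypothetical, and the integer case compares against jnp.iinfo rather than a hard-coded constant.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
idx = jnp.array([0, 2, 5])   # index 5 is out of bounds for length 4

filled = jnp.take(x, idx)                   # default mode="fill": NaN at the bad index
custom = jnp.take(x, idx, fill_value=0.0)   # same mode, explicit fill value
clipped = jnp.take(x, idx, mode="clip")     # assumed: 5 is clamped to the last index

# Signed integers fall back to the most negative representable value.
xi = jnp.array([10, 20, 30])
int_filled = jnp.take(xi, jnp.array([3]))
print(filled, custom, clipped, int_filled)
```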
- Returns
out – The returned array has the same type as a.
- Return type
ndarray (Ni..., Nj..., Nk...)
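An illustration of the return shape (Ni..., Nj..., Nk...) and of the flattening that happens when axis is None. The arrays are hypothetical; with axis=None the result is expected to match indexing the raveled input.

```python
import jax.numpy as jnp

a = jnp.arange(2 * 5 * 3).reshape(2, 5, 3)      # Ni = (2,), M = 5, Nk = (3,)
indices = jnp.array([[0, 4], [1, 1], [3, 2]])   # Nj = (3, 2)

out = jnp.take(a, indices, axis=1)
print(out.shape)   # (2, 3, 2, 3), i.e. Ni + Nj + Nk

# axis=None: the input is flattened first, like a.ravel()[flat_idx].
flat_idx = jnp.array([0, 7, 29])
flat = jnp.take(a, flat_idx)
```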