jax.numpy.take

jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array along an axis.

LAX-backend implementation of numpy.take().

The JAX version adds several extra parameters, described below, which are forwarded to jax.lax.gather() for finer control over indexing.

Original docstring below.

When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as np.take(arr, indices, axis=3) is equivalent to arr[:,:,:,indices,...].
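
As a quick check of this equivalence in JAX, the sketch below compares jnp.take along axis=3 with the matching fancy-indexing expression. The array shape and index values are arbitrary illustrative choices, and all indices are in bounds so that the out-of-bounds mode plays no role.

```python
import jax.numpy as jnp

# A small 5-D array so that axis=3 exists (its size along axis 3 is 5).
arr = jnp.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
indices = jnp.array([4, 0, 2])  # all in bounds for axis 3

taken = jnp.take(arr, indices, axis=3)
fancy = arr[:, :, :, indices, ...]

print(taken.shape)                          # (2, 3, 4, 3, 6)
print(bool(jnp.array_equal(taken, fancy)))  # True
```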

Explained without fancy indexing, this is equivalent to the following use of ndindex, which sets each of ii, jj, and kk to a tuple of indices:

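    from numpy import ndindex

    # a: source array; indices: integer array; axis: int;
    # out: preallocated array of shape Ni + Nj + Nk.
    Ni, Nk = a.shape[:axis], a.shape[axis+1:]
    Nj = indices.shape
    for ii in ndindex(Ni):
        for jj in ndindex(Nj):
            for kk in ndindex(Nk):
                out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
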
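
The loop description can be turned into a runnable reference for in-bounds indices. The sketch below wraps it in a hypothetical helper, take_reference (not part of NumPy or JAX), and compares its result with jnp.take. It runs the loop on NumPy arrays because JAX arrays are immutable and do not support in-place assignment such as out[...] = ....

```python
import numpy as np
import jax.numpy as jnp

def take_reference(a, indices, axis):
    """Hypothetical loop-based take following the ndindex description.

    Assumes every index is in bounds; no out-of-bounds mode is handled.
    """
    a = np.asarray(a)
    indices = np.asarray(indices)
    Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
    Nj = indices.shape
    out = np.empty(Ni + Nj + Nk, dtype=a.dtype)
    for ii in np.ndindex(Ni):
        for jj in np.ndindex(Nj):
            for kk in np.ndindex(Nk):
                out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
    return out

a = np.arange(24).reshape(2, 3, 4)
idx = np.array([[0, 2], [1, 1]])  # 2-D indices, so Nj == (2, 2)

ref = take_reference(a, idx, axis=1)
res = jnp.take(a, idx, axis=1)

print(ref.shape)                             # (2, 2, 2, 4)
print(np.array_equal(ref, np.asarray(res)))  # True
```

Note that the output shape is Ni + Nj + Nk: the single axis being indexed is replaced by all the axes of indices.
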
Parameters
  • a (array_like (Ni..., M, Nk...)) – The source array.

  • indices (array_like (Nj...)) – The indices of the values to extract.

  • axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.

  • mode (string, default="fill") – Out-of-bounds indexing mode. The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices. See jax.numpy.ndarray.at for more discussion of out-of-bounds indexing in JAX.

  • unique_indices (bool, default=False) – If True, the implementation will assume that the indices are unique, which can result in more efficient execution on some backends.

  • indices_are_sorted (bool, default=False) – If True, the implementation will assume that the indices are sorted in ascending order, which can lead to more efficient execution on some backends.

  • fill_value (optional) – The fill value to return for out-of-bounds slices when mode is "fill". Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.
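
The parameters above can be exercised together on a small array. The sketch below illustrates axis=None (flattened input), the default "fill" mode and its dtype-dependent defaults, a custom fill_value, mode="clip", and the indices_are_sorted / unique_indices hints. The input values are arbitrary; the hints are only passed where they actually hold, since passing them when they do not hold can give incorrect results.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# axis=None: indices address the flattened array [1, 2, 3, 4, 5, 6].
flat = jnp.take(x, jnp.array([0, 4]))                      # [1., 5.]

# Default mode ("fill"): an out-of-bounds index yields NaN for float input.
filled = jnp.take(x, jnp.array([1, 7]))                    # [2., nan]

# A custom fill_value replaces the default NaN.
custom = jnp.take(x, jnp.array([1, 7]), fill_value=-1.0)   # [2., -1.]

# mode="clip" clamps the index into range instead of filling.
clipped = jnp.take(x, jnp.array([1, 7]), mode="clip")      # [2., 6.]

# Signed integer input: the default fill is the most negative value.
xi = jnp.array([10, 20, 30], dtype=jnp.int32)
int_filled = jnp.take(xi, jnp.array([0, 5]))               # [10, -2147483648]

# Performance hints forwarded to jax.lax.gather; [0, 2] is sorted and unique.
hinted = jnp.take(x, jnp.array([0, 2]), axis=1,
                  indices_are_sorted=True, unique_indices=True)
# [[1., 3.], [4., 6.]]
```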

Returns

out – The returned array has the same type as a.

Return type

ndarray (Ni…, Nj…, Nk…)
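
To make the return shape concrete, the sketch below takes along axis=1 of a float16 array with 2-D indices. The shapes are arbitrary illustrative choices; the point is that the result has shape Ni… + Nj… + Nk… and keeps the dtype of a.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 5, 3), dtype=jnp.float16)    # Ni... = (2,), M = 5, Nk... = (3,)
indices = jnp.array([[0, 1], [2, 3], [4, 0]])  # Nj... = (3, 2)

out = jnp.take(a, indices, axis=1)
print(out.shape)  # (2, 3, 2, 3), i.e. Ni... + Nj... + Nk...
print(out.dtype)  # float16, the same dtype as a
```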