jax.numpy.take

jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array along an axis.

LAX-backend implementation of numpy.take().

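By default, JAX does not raise errors for out-of-bounds indices: with mode=None, which behaves like mode="fill", such indices return a fill value (see fill_value below). Alternative out-of-bounds index semantics can be specified via the mode parameter (see below).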

Original docstring below.

When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as np.take(arr, indices, axis=3) is equivalent to arr[:,:,:,indices,...].
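As a quick check of the stated equivalence, this sketch compares jnp.take with axis=3 against the corresponding fancy-indexing expression on an arbitrary 5-D array. The selected axis is replaced by the shape of indices, here (3,).

```python
import jax.numpy as jnp

arr = jnp.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
indices = jnp.array([4, 0, 2])

taken = jnp.take(arr, indices, axis=3)
fancy = arr[:, :, :, indices, ...]

print(taken.shape)  # (2, 3, 4, 3, 6)
assert jnp.array_equal(taken, fancy)
```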

Explained without fancy indexing, this is equivalent to the following use of ndindex, which sets each of ii, jj, and kk to a tuple of indices:

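from numpy import empty, ndindex

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
out = empty(Ni + Nj + Nk, dtype=a.dtype)  # result shape is (Ni..., Nj..., Nk...)
for ii in ndindex(Ni):
    for jj in ndindex(Nj):
        for kk in ndindex(Nk):
            out[ii + jj + kk] = a[ii + (indices[jj],) + kk]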
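To see that the loop formulation matches jnp.take, the sketch below wraps the loops in a hypothetical helper, take_by_loops (not part of JAX or NumPy), and compares its output with jnp.take for a 2-D indices array. This is an illustration of the semantics, not how JAX implements take (JAX lowers it to a gather operation).

```python
import numpy as np
import jax.numpy as jnp

def take_by_loops(a, indices, axis):
    """Hypothetical reference helper: take along `axis` using explicit loops."""
    a = np.asarray(a)
    indices = np.asarray(indices)
    Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
    Nj = indices.shape
    out = np.empty(Ni + Nj + Nk, dtype=a.dtype)
    for ii in np.ndindex(Ni):
        for jj in np.ndindex(Nj):
            for kk in np.ndindex(Nk):
                out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
    return out

a = np.arange(24).reshape(2, 3, 4)
indices = np.array([[0, 2], [1, 1]])  # 2-D indices, so Nj == (2, 2)

ref = take_by_loops(a, indices, axis=1)
result = jnp.take(a, indices, axis=1)

print(result.shape)  # (2, 2, 2, 4)
assert np.array_equal(ref, np.asarray(result))
```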

Parameters:
  • a (array_like (Ni..., M, Nk...)) – The source array.

  • indices (array_like (Nj...)) – The indices of the values to extract.

  • axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.

  • mode (string, default="fill") – Out-of-bounds indexing mode. The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices (see also fill_value below), while mode="clip" clamps indices to the valid range. For more discussion of mode options, see jax.numpy.ndarray.at.

  • fill_value (optional) – The fill value to return for out-of-bounds indices when mode is "fill". Ignored otherwise. Defaults to NaN for inexact types, the minimum representable value for signed integer types, the maximum representable value for unsigned integer types, and True for booleans.

  • unique_indices (bool, default=False) – If True, the implementation will assume that the indices are unique, which can result in more efficient execution on some backends.

  • indices_are_sorted (bool, default=False) – If True, the implementation will assume that the indices are sorted in ascending order, which can lead to more efficient execution on some backends.

  • out (None) – Not supported by JAX, whose arrays are immutable; must be None.
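The parameters above can be exercised together in a short sketch. It assumes an int32 input so that the default fill value is the int32 minimum; the array contents and names are arbitrary. The sorted/unique hints are only safe to pass when the indices really are sorted and unique.

```python
import jax.numpy as jnp

m = jnp.array([[1, 2, 3],
               [4, 5, 6]], dtype=jnp.int32)

# axis=None: indices refer to the flattened array [1, 2, 3, 4, 5, 6].
flat = jnp.take(m, jnp.array([0, 4, 5]))  # -> [1, 5, 6]

# Out-of-bounds index with the default fill for a signed integer dtype:
# the result is the minimum representable int32 value.
default_fill = jnp.take(m, jnp.array([10]))  # -> [-2147483648]

# An explicit fill_value overrides the dtype-dependent default.
custom_fill = jnp.take(m, jnp.array([1, 10]), fill_value=-1)  # -> [2, -1]

# Performance hints: set these only when they actually hold for the indices.
sorted_unique = jnp.take(m, jnp.array([0, 2]), axis=1,
                         indices_are_sorted=True, unique_indices=True)  # -> [[1, 3], [4, 6]]
```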

Returns:

out – The returned array has the same type as a.

Return type:

ndarray (Ni…, Nj…, Nk…)