jax.numpy.take(a, indices, axis=None, out=None, mode=None)

Take elements from an array along an axis.

LAX-backend implementation of take(). Original docstring below.

When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as np.take(arr, indices, axis=3) is equivalent to arr[:,:,:,indices,...].
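The equivalence stated above can be checked directly. The following is a minimal sketch using NumPy, since the docstring is NumPy's own; the 5-D shape and the index values are arbitrary choices made for illustration.

```python
import numpy as np

# A 5-D array so that axis=3 is a valid axis (shape chosen for illustration).
arr = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
indices = np.array([4, 0, 2])

# take along axis 3 ...
taken = np.take(arr, indices, axis=3)
# ... versus fancy indexing in the fourth position.
fancy = arr[:, :, :, indices, ...]

same_values = np.array_equal(taken, fancy)
print(taken.shape, same_values)
```

Both expressions replace the length-5 axis with one entry per index; the take form avoids having to count leading colons.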

Explained without fancy indexing, this is equivalent to the following use of ndindex, which sets each of ii, jj, and kk to a tuple of indices:

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
    for jj in ndindex(Nj):
        for kk in ndindex(Nk):
            out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
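
The loop above leaves a, indices, axis and out undefined. A runnable sketch, assuming small illustrative inputs and an out buffer allocated with np.empty (both are assumptions, not part of the original docstring), might look like this:

```python
import numpy as np
from numpy import ndindex

# Illustrative inputs (assumed for this sketch).
a = np.arange(24).reshape(2, 3, 4)
indices = np.array([[2, 0], [1, 1]])
axis = 1

Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
Nj = indices.shape

# The result has shape Ni + Nj + Nk.
out = np.empty(Ni + Nj + Nk, dtype=a.dtype)
for ii in ndindex(Ni):
    for jj in ndindex(Nj):
        for kk in ndindex(Nk):
            out[ii + jj + kk] = a[ii + (indices[jj],) + kk]

# Compare the explicit loops with the library call.
loop_matches_take = np.array_equal(out, np.take(a, indices, axis=axis))
print(out.shape, loop_matches_take)
```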

Parameters

  • a (array_like (Ni..., M, Nk...)) – The source array.

  • indices (array_like (Nj...)) – The indices of the values to extract.

  • axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.

  • out (ndarray, optional (Ni..., Nj..., Nk...)) – If provided, the result will be placed in this array. It should be of the appropriate shape and dtype. Note that out is always buffered if mode='raise'; use other modes for better performance.

  • mode ({'raise', 'wrap', 'clip'}, optional) – Specifies how out-of-bounds indices will behave.
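
As an illustration of the mode parameter, the sketch below shows how NumPy's take treats out-of-bounds indices in each mode. It describes NumPy's behaviour only; whether the JAX implementation handles every mode the same way is not stated here and should be checked separately. The array and index values are arbitrary.

```python
import numpy as np

a = np.array([4, 3, 5, 7, 6, 8])   # length 6, so valid positions are 0..5
idx = np.array([7, -1, 2])

# 'wrap': indices are reduced modulo the axis length (7 -> 1, -1 -> 5).
wrapped = np.take(a, idx, mode="wrap")

# 'clip': indices are clamped to [0, 5], so negative indices map to 0.
clipped = np.take(a, idx, mode="clip")

# 'raise': an out-of-range index is an error.
try:
    np.take(a, [7], mode="raise")
    raised = False
except IndexError:
    raised = True

print(wrapped, clipped, raised)
```

Note that in 'clip' mode a negative index no longer counts from the end of the axis.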


Returns

out – The returned array has the same type as a.

Return type

ndarray (Ni..., Nj..., Nk...)
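
The result shape (Ni..., Nj..., Nk...) can be illustrated with a short NumPy sketch. The shapes below are assumptions chosen so that each group is easy to spot.

```python
import numpy as np

a = np.zeros((2, 5, 3))                            # Ni = (2,), M = 5, Nk = (3,)
indices = np.array([[0, 1, 2, 3], [4, 3, 2, 1]])   # Nj = (2, 4)

out = np.take(a, indices, axis=1)

# The selected axis is replaced by the full shape of indices.
expected_shape = a.shape[:1] + indices.shape + a.shape[2:]
print(out.shape, out.dtype)
```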

See also


compress – Take elements using a boolean mask

ndarray.take – equivalent method

take_along_axis – Take elements by matching the array and the index arrays


Notes

By eliminating the inner loop in the description above, and using s_ to build simple slice objects, take can be expressed in terms of applying fancy indexing to each 1-d slice:

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices]
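
A runnable sketch of the per-slice formulation follows, with the outer loops running over Ni and Nk. The inputs and the np.empty output buffer are assumptions made for illustration.

```python
import numpy as np
from numpy import ndindex, s_

# Illustrative inputs (assumed for this sketch).
a = np.arange(24).reshape(2, 3, 4)
indices = np.array([2, 0, 1, 1])
axis = 1

Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
Nj = indices.shape
out = np.empty(Ni + Nj + Nk, dtype=a.dtype)

for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        # s_[:,] selects the whole 1-d slice along `axis`;
        # s_[...,] addresses the matching Nj-shaped slot in `out`.
        out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices]

slices_match_take = np.array_equal(out, np.take(a, indices, axis=axis))
print(out.shape, slices_match_take)
```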

For this reason, it is equivalent to (but faster than) the following use of apply_along_axis:

out = np.apply_along_axis(lambda a_1d: a_1d[indices], axis, a)
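
The apply_along_axis form can be compared against take directly. This is a minimal check with assumed inputs; it verifies equality of results only and makes no timing measurement.

```python
import numpy as np

# Illustrative inputs (assumed for this sketch).
a = np.arange(24).reshape(2, 3, 4)
indices = np.array([2, 0])
axis = 1

# Fancy-index every 1-d slice along `axis` ...
via_apply = np.apply_along_axis(lambda a_1d: a_1d[indices], axis, a)
# ... which should agree with a single take call.
via_take = np.take(a, indices, axis=axis)

apply_matches_take = np.array_equal(via_apply, via_take)
print(via_apply.shape, apply_matches_take)
```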


Examples

>>> a = [4, 3, 5, 7, 6, 8]
>>> indices = [0, 1, 4]
>>> np.take(a, indices)
array([4, 3, 6])

In this example, if a is an ndarray, “fancy” indexing can be used.

>>> a = np.array(a)
>>> a[indices]
array([4, 3, 6])

If indices is not one-dimensional, the output takes on the shape of indices.

>>> np.take(a, [[0, 1], [2, 3]])
array([[4, 3],
       [5, 7]])