jax.numpy.take
jax.numpy.take(a, indices, axis=None, out=None, mode=None)
Take elements from an array along an axis.
LAX-backend implementation of take(). Original docstring below.
When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as np.take(arr, indices, axis=3) is equivalent to arr[:,:,:,indices,...].
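The equivalence between take along axis 3 and fancy indexing on that axis can be checked directly with jax.numpy. This is a minimal sketch that assumes JAX is installed. The 5-d shape and the index values are arbitrary illustrations.

```python
import numpy as np
import jax.numpy as jnp

# A 5-d array, so axis=3 has axes on both sides of it.
arr = jnp.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
indices = jnp.array([4, 0, 2])

taken = jnp.take(arr, indices, axis=3)
fancy = arr[:, :, :, indices, ...]  # fancy indexing on axis 3 only

print(taken.shape)                   # (2, 3, 4, 3, 6)
print(np.array_equal(taken, fancy))  # True
```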
Explained without fancy indexing, take is equivalent to the following use of ndindex, which sets each of ii, jj, and kk to a tuple of indices:
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
    for jj in ndindex(Nj):
        for kk in ndindex(Nk):
            out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
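The following is a runnable version of the triple loop above, compared against jnp.take. It assumes JAX is installed. The helper take_reference is hypothetical and exists only for illustration. It fills a NumPy buffer because JAX arrays cannot be updated in place.

```python
import numpy as np
import jax.numpy as jnp

def take_reference(a, indices, axis):
    """Hypothetical helper: take(a, indices, axis) written as the ndindex triple loop."""
    a = np.asarray(a)
    indices = np.asarray(indices)
    Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
    Nj = indices.shape
    out = np.empty(Ni + Nj + Nk, dtype=a.dtype)  # result shape is (Ni..., Nj..., Nk...)
    for ii in np.ndindex(Ni):
        for jj in np.ndindex(Nj):
            for kk in np.ndindex(Nk):
                out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
    return out

a = np.arange(24).reshape(2, 3, 4)
indices = np.array([[2, 0], [1, 1]])  # 2-d indices, so Nj == (2, 2)

result = take_reference(a, indices, axis=1)
expected = jnp.take(jnp.asarray(a), jnp.asarray(indices), axis=1)

print(result.shape)                      # (2, 2, 2, 4)
print(np.array_equal(result, expected))  # True
```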
Parameters
a (array_like (Ni..., M, Nk...)) – The source array.
indices (array_like (Nj...)) – The indices of the values to extract.
axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.
out (ndarray, optional (Ni..., Nj..., Nk...)) – If provided, the result will be placed in this array. It should be of the appropriate shape and dtype. Note that out is always buffered if mode='raise'; use other modes for better performance.
mode ({'raise', 'wrap', 'clip'}, optional) – Specifies how out-of-bounds indices will behave.
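The sketch below exercises axis=None and the 'wrap' and 'clip' modes with jax.numpy. It assumes JAX is installed. It uses only in-bounds indices under the default mode and does not use mode='raise', because JAX's handling of those cases may differ from NumPy's and from one JAX release to another. Check the documentation of the installed version before relying on them.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([4, 3, 5, 7, 6, 8])

# 'wrap': indices are reduced modulo the axis length (7 -> 1, -1 -> 5).
wrapped = jnp.take(a, jnp.array([7, -1]), mode="wrap")

# 'clip': too-large indices are replaced by the last valid index (10 -> 5).
clipped = jnp.take(a, jnp.array([10, 2]), mode="clip")

# axis=None (the default): the input is flattened before indexing.
m = jnp.array([[1, 2, 3], [4, 5, 6]])
flat = jnp.take(m, jnp.array([0, 4]))  # picks elements 0 and 4 of [1, 2, 3, 4, 5, 6]

print(wrapped)  # [3 8]
print(clipped)  # [8 5]
print(flat)     # [1 5]
```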
Returns
out – The returned array has the same type as a.
Return type
ndarray (Ni..., Nj..., Nk...)
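The return shape (Ni..., Nj..., Nk...) splices the shape of indices in place of the selected axis. This small sketch checks that rule with jax.numpy, assuming JAX is installed. The shapes are arbitrary.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3, 5, 7))                    # Ni = (2, 3), M = 5, Nk = (7,)
indices = jnp.zeros((4, 6), dtype=jnp.int32)   # Nj = (4, 6)
axis = 2

out = jnp.take(a, indices, axis=axis)
expected_shape = a.shape[:axis] + indices.shape + a.shape[axis + 1:]

print(out.shape)                    # (2, 3, 4, 6, 7)
print(out.shape == expected_shape)  # True
print(out.dtype == a.dtype)         # True: same dtype as the input
```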
See also
compress() – Take elements using a boolean mask
ndarray.take() – Equivalent method
take_along_axis() – Take elements by matching the array and the index arrays
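The functions listed above differ in how the selection is specified. take applies one list of indices to every slice along an axis. take_along_axis takes an index array with the same number of dimensions, so each slice gets its own indices. compress selects with a boolean mask. This sketch assumes JAX is installed and uses the jax.numpy counterparts of these NumPy functions.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[10, 30, 20],
               [60, 40, 50]])

# take: the same column indices are applied to every row.
cols = jnp.take(x, jnp.array([2, 0]), axis=1)       # [[20, 10], [50, 60]]

# take_along_axis: one index row per data row, e.g. a per-row argsort.
order = jnp.argsort(x, axis=1)                       # [[0, 2, 1], [1, 2, 0]]
sorted_rows = jnp.take_along_axis(x, order, axis=1)  # [[10, 20, 30], [40, 50, 60]]

# compress: keep the columns where the boolean mask is True.
kept = jnp.compress(jnp.array([True, False, True]), x, axis=1)  # [[10, 20], [60, 50]]
```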
Notes
By eliminating the inner loop in the description above, and using s_ to build simple slice objects, take can be expressed in terms of applying fancy indexing to each 1-d slice:
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices]
For this reason, it is equivalent to (but faster than) the following use of apply_along_axis:
out = np.apply_along_axis(lambda a_1d: a_1d[indices], axis, a)
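Both formulations in these Notes can be checked against jnp.take. This sketch assumes JAX is installed. The helper take_by_slices is hypothetical: it fills a NumPy buffer because JAX arrays cannot be written in place. The second form uses jnp.apply_along_axis, the jax.numpy counterpart of np.apply_along_axis. Nothing here measures speed; it only checks that the results agree.

```python
import numpy as np
import jax.numpy as jnp

def take_by_slices(a, indices, axis):
    """Hypothetical helper: take() as fancy indexing of each 1-d slice along `axis`."""
    a = np.asarray(a)
    indices = np.asarray(indices)
    Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
    out = np.empty(Ni + indices.shape + Nk, dtype=a.dtype)
    for ii in np.ndindex(Ni):
        for kk in np.ndindex(Nk):
            # a[ii + s_[:,] + kk] is one 1-d slice along `axis`;
            # out[ii + s_[...,] + kk] is the matching block of shape indices.shape.
            out[ii + np.s_[...,] + kk] = a[ii + np.s_[:,] + kk][indices]
    return out

a = np.arange(60.0).reshape(3, 4, 5)
indices = np.array([3, 0, 3])
axis = 1

expected = jnp.take(jnp.asarray(a), jnp.asarray(indices), axis=axis)
by_slices = take_by_slices(a, indices, axis)
by_apply = jnp.apply_along_axis(lambda a_1d: a_1d[jnp.asarray(indices)], axis, jnp.asarray(a))

print(np.array_equal(by_slices, expected))  # True
print(np.array_equal(by_apply, expected))   # True
```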
Examples
>>> a = [4, 3, 5, 7, 6, 8]
>>> indices = [0, 1, 4]
>>> np.take(a, indices)
array([4, 3, 6])
In this example, if a is an ndarray, “fancy” indexing can be used instead.
>>> a = np.array(a)
>>> a[indices]
array([4, 3, 6])
If indices is not one dimensional, the output also has these dimensions.
>>> np.take(a, [[0, 1], [2, 3]])
array([[4, 3],
       [5, 7]])
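The examples above use NumPy. This sketch repeats the 2-d-indices example with jax.numpy and also calls take inside a jitted function with a fixed axis. It assumes JAX is installed. Inputs are passed as JAX arrays rather than Python lists because recent jax.numpy functions may reject lists. The function name take_cols is hypothetical.

```python
import jax
import jax.numpy as jnp

a = jnp.array([4, 3, 5, 7, 6, 8])

# 2-d indices give a 2-d result, as in the NumPy example above.
grid = jnp.take(a, jnp.array([[0, 1], [2, 3]]))
print(grid)
# [[4 3]
#  [5 7]]

@jax.jit
def take_cols(m, idx):
    # axis is a Python constant here, so it is fixed at trace time.
    return jnp.take(m, idx, axis=1)

m = a.reshape(2, 3)  # [[4, 3, 5], [7, 6, 8]]
cols = take_cols(m, jnp.array([2, 0]))
print(cols)
# [[5 4]
#  [8 7]]
```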