jax.numpy.take_along_axis(arr, indices, axis, mode=None)

Take values from the input array by matching 1d index and data slices.

LAX-backend implementation of numpy.take_along_axis().

Unlike numpy.take_along_axis(), jax.numpy.take_along_axis() takes an optional mode parameter that controls how out-of-bounds indices are handled. By default (mode=None), out-of-bounds indices yield invalid values (e.g., NaN for floating-point arrays). See jax.numpy.ndarray.at for further discussion of out-of-bounds indexing in JAX.
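A minimal sketch contrasting the default behavior with mode="clip", assuming a floating-point input so that the invalid fill value is NaN. The array values are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # index 5 is out of bounds for an axis of length 3

# Default (mode=None): the out-of-bounds position is filled with an invalid value (NaN here).
filled = jnp.take_along_axis(x, idx, axis=0)

# mode="clip": out-of-bounds indices are clamped into the valid range [0, 2].
clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")

print(filled)   # [10. 30. nan]
print(clipped)  # [10. 30. 30.]
```

Because JAX cannot raise errors from inside traced code, choosing between filling and clipping is how the caller decides what an invalid lookup should produce.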

Original docstring below.

This iterates over matching 1d slices oriented along the specified axis in the index and data arrays, and uses the former to look up values in the latter. These slices can have different lengths.
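The slice-matching rule can be checked against an explicit loop. In this sketch (values chosen only for illustration) the index slices have length 2 while the data slices have length 4:

```python
import jax.numpy as jnp

arr = jnp.arange(8).reshape(2, 4)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
indices = jnp.array([[3, 0],
                     [1, 1]])        # length-2 index slices along axis 1

out = jnp.take_along_axis(arr, indices, axis=1)

# Reference loop: row i of the output is row i of arr looked up at row i of indices.
arr_rows = arr.tolist()
idx_rows = indices.tolist()
expected = [[arr_rows[i][j] for j in idx_rows[i]] for i in range(len(arr_rows))]

print(out.tolist())  # [[3, 0], [5, 5]]
```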

Functions returning an index along an axis, like argsort and argpartition, produce suitable indices for this function.
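A short sketch of the typical pairing with argsort. It also shows jnp.argmax, which is not listed above: argmax removes the reduced axis, so the sketch restores it with jnp.expand_dims to keep the number of dimensions equal to that of the data array.

```python
import jax.numpy as jnp

a = jnp.array([[3, 1, 2],
               [9, 7, 8]])

# argsort gives, for each row, the positions that would sort that row.
order = jnp.argsort(a, axis=1)
sorted_a = jnp.take_along_axis(a, order, axis=1)  # matches jnp.sort(a, axis=1)

# argmax drops axis 1, so re-insert it before looking values up.
best = jnp.expand_dims(jnp.argmax(a, axis=1), axis=1)  # shape (2, 1)
row_max = jnp.take_along_axis(a, best, axis=1)         # shape (2, 1)

print(sorted_a.tolist())  # [[1, 2, 3], [7, 8, 9]]
print(row_max.tolist())   # [[3], [9]]
```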

Added in version 1.15.0.

Parameters:

  • arr (ndarray (Ni..., M, Nk...)) – Source array

  • indices (ndarray (Ni..., J, Nk...)) – Indices to take along each 1d slice of arr. It must have the same number of dimensions as arr, but dimensions Ni and Nk only need to broadcast against arr.

  • axis (int or None) – The axis to take 1d slices along. If axis is None, the input array is treated as if it had first been flattened to 1d, for consistency with sort and argsort.

  • mode (str | GatherScatterMode | None) – How out-of-bounds indices are handled; see jax.numpy.ndarray.at.
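The broadcasting rule for indices and the axis=None case can be sketched as follows. The values are illustrative, and the axis=None call passes 1-D indices so that they line up with the flattened array.

```python
import jax.numpy as jnp

arr = jnp.array([[10, 20, 30],
                 [40, 50, 60]])      # shape (2, 3)

# indices has size 1 along axis 0, so it broadcasts against the 2 rows of arr.
idx = jnp.array([[2, 0]])            # shape (1, 2)
out = jnp.take_along_axis(arr, idx, axis=1)
print(out.tolist())   # [[30, 10], [60, 40]]

# axis=None: arr is treated as the flattened array [10, 20, 30, 40, 50, 60].
flat = jnp.take_along_axis(arr, jnp.array([5, 0]), axis=None)
print(flat.tolist())  # [60, 10]
```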

Returns:

out – The indexed result.

Return type:

ndarray (Ni…, J, Nk…)
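A quick shape check of the return type, assuming a 3-D array with axis=1 so that (Ni, M, Nk) = (2, 5, 3) and (Ni, J, Nk) = (2, 4, 3). The all-zero indices are a placeholder chosen only to make the result easy to verify: every output entry is arr[i, 0, k].

```python
import jax.numpy as jnp

arr = jnp.arange(2 * 5 * 3).reshape(2, 5, 3)     # (Ni, M, Nk) = (2, 5, 3)
indices = jnp.zeros((2, 4, 3), dtype=jnp.int32)  # (Ni, J, Nk) = (2, 4, 3)

out = jnp.take_along_axis(arr, indices, axis=1)
print(out.shape)  # (2, 4, 3): M is replaced by J; Ni and Nk are unchanged
```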