jax.numpy.argsort

jax.numpy.argsort(a, axis=-1, kind=None, order=None, *, stable=True, descending=False)

Returns the indices that would sort an array.

LAX-backend implementation of numpy.argsort().

Original docstring below.

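Perform an indirect sort along the given axis, using a stable or unstable algorithm as selected by the stable keyword (the kind keyword is deprecated in JAX). It returns an array of indices with the same shape as a (or the shape of the flattened a, if axis is None) that index data along the given axis in sorted order.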
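A minimal 1-D sketch of the behavior just described. The array x is an arbitrary illustrative input, and the expected outputs are shown as comments.

```python
import jax.numpy as jnp

# A small 1-D array with no repeated values.
x = jnp.array([30, 10, 20])

# argsort returns the positions of the elements in ascending order.
idx = jnp.argsort(x)
print(idx)     # [1 2 0]

# Indexing with those positions yields the sorted array.
print(x[idx])  # [10 20 30]
```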

Parameters:
  • a (array_like) – Array to sort.

  • axis (int or None, optional) – Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used.

  • kind – Deprecated; specify the sort algorithm using stable=True or stable=False instead.

  • order – Not supported.

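  • stable (bool, default=True) – If True, use a stable sort, so elements that compare equal keep their original relative order.

  • descending (bool, default=False) – If True, sort in descending order (largest values first) instead of ascending order.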
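The sketch below shows how axis, stable, and descending affect the result. It assumes a JAX version whose jnp.argsort accepts the stable and descending keywords listed above. The arrays x and m are arbitrary examples. For stable=False and descending=True, only the resulting sorted values are shown, because the order of tied indices is not relied on here.

```python
import jax.numpy as jnp

# Two equal values (the 2s at positions 0 and 2) make stability visible.
x = jnp.array([2, 1, 2, 0])

# Default stable ascending sort: ties keep their original order,
# so index 0 comes before index 2.
print(jnp.argsort(x))                    # [3 1 0 2]

# stable=False may order ties either way, but the sorted values are the same.
unstable = jnp.argsort(x, stable=False)
print(x[unstable])                       # [0 1 2 2]

# descending=True puts the largest values first.
desc = jnp.argsort(x, descending=True)
print(x[desc])                           # [2 2 1 0]

# axis=None sorts the flattened array; the indices refer to m.ravel().
m = jnp.array([[3, 1], [0, 2]])
flat_idx = jnp.argsort(m, axis=None)
print(flat_idx)                          # [2 1 3 0]
print(m.ravel()[flat_idx])               # [0 1 2 3]
```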

Returns:

index_array – Array of indices that sort a along the specified axis. If a is one-dimensional, a[index_array] yields a sorted a. More generally, np.take_along_axis(a, index_array, axis=axis) always yields the sorted a, irrespective of dimensionality.

Return type:

ndarray of int
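For arrays with more than one dimension, a[index_array] does not produce the sorted array, because a 2-D index array only indexes the first axis of a. The following sketch pairs argsort with jnp.take_along_axis, the JAX counterpart of the np.take_along_axis call described under Returns, to sort an arbitrary 2-D array a along each axis.

```python
import jax.numpy as jnp

a = jnp.array([[3, 1, 2],
               [0, 5, 4]])

# Sort each row (axis=-1, the default).
idx = jnp.argsort(a, axis=-1)
print(idx)          # [[1 2 0]
                    #  [0 2 1]]

# take_along_axis gathers along the same axis the indices were computed on.
rows_sorted = jnp.take_along_axis(a, idx, axis=-1)
print(rows_sorted)  # [[1 2 3]
                    #  [0 4 5]]

# Sorting down each column (axis=0) works the same way.
col_idx = jnp.argsort(a, axis=0)
cols_sorted = jnp.take_along_axis(a, col_idx, axis=0)
print(cols_sorted)  # [[0 1 2]
                    #  [3 5 4]]
```

In both cases the result matches jnp.sort(a, axis=...) for the same axis.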