jax.numpy.lexsort



jax.numpy.lexsort(keys, axis=-1)

Perform an indirect stable sort using a sequence of keys.

LAX-backend implementation of numpy.lexsort().

Original docstring below.

Given multiple sorting keys, which can be interpreted as columns in a spreadsheet, lexsort returns an array of integer indices that describes the sort order across multiple columns. The last key in the sequence is used for the primary sort order, the second-to-last key for the secondary sort order, and so on. The keys argument must be a sequence of objects that can be converted to arrays of the same shape. If a 2D array is provided for the keys argument, its rows are interpreted as the sorting keys, and sorting is by the last row, then the second-to-last row, and so on.
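As a minimal sketch (the key names `a` and `b` and their values are illustrative, not part of the original docstring), the call below sorts primarily by `a` because it is passed last, with `b` breaking ties. Passing the same keys stacked into a 2D array gives the same order, since each row is treated as one key.

```python
import jax.numpy as jnp

# Illustrative integer keys (hypothetical data).
a = jnp.array([1, 5, 1, 4, 3, 4, 4])  # primary key: passed last
b = jnp.array([9, 4, 0, 4, 0, 2, 1])  # secondary key: breaks ties in `a`

# Keys as a tuple: sort by `a`, then by `b` wherever `a` is equal.
ind = jnp.lexsort((b, a))
print(ind)  # [2 0 4 6 5 3 1]

# Keys as a 2D array: each row is a key and the last row is primary,
# so stacking (b, a) row-wise yields the same indices.
ind_2d = jnp.lexsort(jnp.stack([b, a]))
print(ind_2d)  # [2 0 4 6 5 3 1]
```

Indices 2 and 0 come first because both rows have `a == 1`; the secondary key `b` (0 before 9) decides which of the two leads.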

Parameters:
  • keys ((k, N) array or tuple containing k (N,)-shaped sequences) – The k different “columns” to be sorted. The last column (or row if keys is a 2D array) is the primary sort key.

  • axis (int, optional) – Axis to be indirectly sorted. By default, sort over the last axis.
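The `axis` argument selects which axis of the equally shaped keys is sorted; slices along every other axis are sorted independently. A sketch with two hypothetical keys of shape (2, 3), cross-checked against `numpy.lexsort`:

```python
import jax.numpy as jnp
import numpy as np

# Two hypothetical keys with the same (2, 3) shape; `primary` is passed last.
primary = jnp.array([[2, 1, 0],
                     [0, 1, 1]])
secondary = jnp.array([[5, 9, 3],
                       [7, 6, 8]])

# Default axis=-1: each row is sorted on its own.
by_row = jnp.lexsort((secondary, primary))
# by_row -> [[2, 1, 0], [0, 1, 2]]

# axis=0: each column is sorted on its own. In column 1 the primary key
# ties (1 and 1), so the secondary key (9 vs 6) puts row 1 first.
by_col = jnp.lexsort((secondary, primary), axis=0)
# by_col -> [[1, 1, 0], [0, 0, 1]]

# Cross-check against NumPy's implementation.
np_keys = (np.asarray(secondary), np.asarray(primary))
assert np.array_equal(np.asarray(by_row), np.lexsort(np_keys, axis=-1))
assert np.array_equal(np.asarray(by_col), np.lexsort(np_keys, axis=0))
```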

Returns:

indices – Array of indices that sort the keys along the specified axis.

Return type:

(N,) ndarray of ints
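Because the result is an index array, a common pattern is to gather every key with it to obtain the sorted rows. The sketch below uses hypothetical `score` and `age` data and also illustrates the stability mentioned in the summary line: rows whose keys are all equal keep their original relative order.

```python
import jax.numpy as jnp

# Hypothetical records: sort primarily by score, then by age.
score = jnp.array([3, 1, 3, 1, 2])
age = jnp.array([30, 25, 20, 25, 40])

ind = jnp.lexsort((age, score))
print(ind)  # [1 3 4 2 0]

# Apply the same indices to every key to reorder the whole "spreadsheet".
print(score[ind])  # [1 1 2 3 3]
print(age[ind])    # [25 25 40 20 30]

# Rows 1 and 3 are identical in both keys; the stable sort keeps 1 before 3.
```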