jax.numpy.sort

jax.numpy.sort(a, axis=-1, kind='quicksort', order=None)

Return a sorted copy of an array.

LAX-backend implementation of numpy.sort().

Original docstring below.

Parameters
  • a (array_like) – Array to be sorted.

  • axis (int or None, optional) – Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along the last axis.

  • kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional) –

    Sorting algorithm. The default is ‘quicksort’. Note that both ‘stable’ and ‘mergesort’ use timsort or radix sort under the covers and, in general, the actual implementation will vary with data type. The ‘mergesort’ option is retained for backwards compatibility.

    Changed in version 1.15.0: The ‘stable’ option was added.

  • order (str or list of str, optional) – When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
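As a minimal sketch of how the axis parameter described above behaves, the following sorts a small 2-D array along the last axis, along the first axis, and flattened. The array values and variable names are illustrative only. The kind and order arguments are left at their defaults, since the parameter text is inherited from NumPy and this page does not state how the LAX backend treats them.

```python
import jax.numpy as jnp

# Illustrative input; any array_like works.
x = jnp.array([[3, 1, 2],
               [0, 7, -1]])

# Default axis=-1: each row is sorted independently.
last = jnp.sort(x)
print(last)

# axis=0: each column is sorted independently.
cols = jnp.sort(x, axis=0)
print(cols)

# axis=None: the array is flattened before sorting.
flat = jnp.sort(x, axis=None)
print(flat)
```

With axis=None the result is one-dimensional; for any integer axis the result keeps the shape of the input.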

Returns

sorted_array – Array of the same type and shape as a.

Return type

ndarray
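
The return contract above can be checked with a short sketch: when sorting along an integer axis, the result has the same shape and dtype as the input, and the input itself is left unchanged because a sorted copy is returned. The input values here are illustrative only.

```python
import jax.numpy as jnp

# Illustrative float32 input.
a = jnp.array([[2.5, 0.5],
               [1.5, -1.0]], dtype=jnp.float32)

sorted_a = jnp.sort(a)  # sorts along the last axis by default

# Same shape and dtype as the input.
print(sorted_a.shape == a.shape)
print(sorted_a.dtype == a.dtype)

# The original array is not modified; a sorted copy is returned.
print(a)
print(sorted_a)
```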