jax.numpy.sort

jax.numpy.sort(a, axis=-1, kind=None, order=None, *, stable=True, descending=False)

Return a sorted copy of an array.

LAX-backend implementation of numpy.sort().

Original docstring below.

Parameters:
  • a (array_like) – Array to be sorted.

  • axis (int or None, optional) – Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along the last axis.

  • kind – Deprecated. Specify the sort algorithm using stable=True or stable=False instead.

  • order – Not supported.

  • stable (bool, default=True) – Specify whether to use a stable sort.

  • descending (bool, default=False) – Specify whether to do a descending sort.
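
The axis and descending parameters described above can be illustrated with a minimal sketch. The array x and the result names are made up for this example, and the expected values in the comments assume integer input with no NaNs.

```python
import jax.numpy as jnp

# Hypothetical 2x3 input, chosen so that every call below changes the order.
x = jnp.array([[3, 1, 2],
               [0, 5, 4]])

rows = jnp.sort(x)                     # default axis=-1: each row is sorted
cols = jnp.sort(x, axis=0)             # each column is sorted
flat = jnp.sort(x, axis=None)          # flattened first, then sorted
desc = jnp.sort(x, descending=True)    # each row, largest value first

print(rows.tolist())  # [[1, 2, 3], [0, 4, 5]]
print(cols.tolist())  # [[0, 1, 2], [3, 5, 4]]
print(flat.tolist())  # [0, 1, 2, 3, 4, 5]
print(desc.tolist())  # [[3, 2, 1], [5, 4, 0]]
```

Note that stable and descending are keyword-only, so they have to be passed by name.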

Returns:

sorted_array – Array of the same type as a. It has the same shape as a unless axis=None, in which case it is one-dimensional.

Return type:

ndarray
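
As a rough check of the return value, the sketch below sorts a small float array and compares its shape and dtype with the input. The names a and sorted_a are illustrative only.

```python
import jax.numpy as jnp

# Hypothetical 1-D input; float32 is stated explicitly so the dtype check is unambiguous.
a = jnp.array([2.5, -1.0, 0.5], dtype=jnp.float32)

sorted_a = jnp.sort(a)

print(sorted_a.tolist())            # [-1.0, 0.5, 2.5]
print(sorted_a.shape == a.shape)    # True
print(sorted_a.dtype == a.dtype)    # True
print(a.tolist())                   # [2.5, -1.0, 0.5] (the input is left untouched)
```

Because the function returns a sorted copy, a itself keeps its original order.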