jax.numpy.sort
- jax.numpy.sort(a, axis=-1, kind=None, order=None, *, stable=True, descending=False)
Return a sorted copy of an array.
LAX-backend implementation of numpy.sort(). Original docstring below.
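A minimal usage sketch, assuming only `jax.numpy` is installed: `jnp.sort` returns a new sorted array and leaves its input untouched (JAX arrays are immutable, so there is no in-place `a.sort()` as in NumPy).

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2])
y = jnp.sort(x)  # a new array, sorted along the last axis

print(y.tolist())  # [1, 2, 3]
print(x.tolist())  # [3, 1, 2]  (the input is unchanged)
```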
- Parameters:
a (array_like) – Array to be sorted.
axis (int or None, optional) – Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along the last axis.
kind – Deprecated; choose the sort algorithm with stable=True or stable=False instead.
order – Not supported.
stable (bool, default=True) – Whether to use a stable sort, which preserves the relative order of equal elements.
descending (bool, default=False) – Whether to sort in descending rather than ascending order.
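The sketch below illustrates the `axis` and `descending` parameters on a small 2-D array. It assumes a JAX release whose `jnp.sort` accepts the `descending` keyword shown in the signature; the array values are arbitrary.

```python
import jax.numpy as jnp

m = jnp.array([[3, 7, 2],
               [9, 1, 8]])

rows = jnp.sort(m)                   # axis=-1 (default): sort each row
cols = jnp.sort(m, axis=0)           # sort each column
flat = jnp.sort(m, axis=None)        # flatten first, then sort: result is 1-D
desc = jnp.sort(m, descending=True)  # each row, largest first

print(rows.tolist())  # [[2, 3, 7], [1, 8, 9]]
print(cols.tolist())  # [[3, 1, 2], [9, 7, 8]]
print(flat.tolist())  # [1, 2, 3, 7, 8, 9]
print(desc.tolist())  # [[7, 3, 2], [9, 8, 1]]
```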
- Returns:
sorted_array – Array of the same type and shape as a (or a 1-D array if axis is None).
- Return type:
jax.Array
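A quick check of the return contract, as a sketch: the result keeps the input's dtype, and keeps its shape unless `axis=None`, in which case the sorted result is flattened to one dimension.

```python
import jax.numpy as jnp

a = jnp.array([[2.5, -1.0],
               [0.0, 4.0]], dtype=jnp.float32)

s = jnp.sort(a)               # same shape (2, 2) and dtype float32
flat = jnp.sort(a, axis=None) # flattened before sorting: shape (4,)

print(s.shape, s.dtype)       # (2, 2) float32
print(flat.tolist())          # [-1.0, 0.0, 2.5, 4.0]
```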