jax.lax.sort

jax.lax.sort(operand: Sequence[jax.Array], dimension: int = -1, is_stable: bool = True, num_keys: int = 1) -> Tuple[jax.Array, ...]
jax.lax.sort(operand: jax.Array, dimension: int = -1, is_stable: bool = True, num_keys: int = 1) -> jax.Array

Wraps XLA’s Sort operator.

For floating point inputs, -0.0 and 0.0 are treated as equivalent, and NaN values are sorted to the end of the array. For complex inputs, the sort order is lexicographic over the real and imaginary parts, with the real part primary.
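As a minimal sketch of the ordering rules described above, the following sorts a small floating point array containing a NaN and a small complex array. The input values and variable names are illustrative only, not taken from the JAX documentation.

```python
import jax.numpy as jnp
from jax import lax

# Floating point input: the NaN is expected to be placed last.
x = jnp.array([3.0, jnp.nan, -1.0, 0.0, 2.0])
sorted_x = lax.sort(x)
print(sorted_x)

# Complex input: ordered by real part first, then by imaginary part.
z = jnp.array([1 + 2j, 0 + 5j, 1 + 1j])
sorted_z = lax.sort(z)
print(sorted_z)
```

In the complex case, 0+5j comes first because its real part is smallest, and 1+1j precedes 1+2j because the real parts tie and the imaginary part breaks the tie.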

Parameters
  • operand (Union[Array, Sequence[Array]]) – Array or sequence of arrays

  • dimension (int) – integer dimension along which to sort. Default: -1.

  • is_stable (bool) – boolean specifying whether to use a stable sort. Default: True.

  • num_keys (int) – number of operands to treat as sort keys. Default: 1. For num_keys > 1, the sort order will be determined lexicographically using the first num_keys arrays, with the first key being primary. The remaining operands will be returned with the same permutation.
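The effect of num_keys can be illustrated with a short sketch. The array names (primary, secondary, payload) and their values are hypothetical, chosen so that the primary key contains ties.

```python
import jax.numpy as jnp
from jax import lax

primary = jnp.array([2, 1, 2, 1])
secondary = jnp.array([5, 9, 1, 3])
payload = jnp.array([10, 20, 30, 40])

# num_keys=1 (default): only `primary` determines the order. With the default
# is_stable=True, tied entries keep their original relative order.
p1, pay1 = lax.sort((primary, payload), num_keys=1)
# p1   -> [1, 1, 2, 2]
# pay1 -> [20, 40, 10, 30]

# num_keys=2: ties in `primary` are broken by `secondary`; `payload` is
# permuted the same way but does not influence the order.
p2, s2, pay2 = lax.sort((primary, secondary, payload), num_keys=2)
# p2   -> [1, 1, 2, 2]
# s2   -> [3, 9, 1, 5]
# pay2 -> [40, 20, 30, 10]
```

Passing extra non-key operands in this way is a common pattern for carrying indices or associated values along with a sort.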

Returns

The sorted array if operand is a single array, or a tuple of sorted arrays if operand is a sequence.

Return type

jax.Array if operand is a single array; Tuple[jax.Array, ...] if operand is a sequence of arrays
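The following sketch, using an illustrative 2-D array, shows how dimension selects the sort axis and how the return type follows the form of operand, as the two signatures above indicate.

```python
import jax.numpy as jnp
from jax import lax

m = jnp.array([[3, 1, 2],
               [0, 5, 4]])

# Default dimension=-1: each row is sorted independently.
by_last = lax.sort(m)
# by_last -> [[1, 2, 3], [0, 4, 5]]

# dimension=0: each column is sorted independently.
by_first = lax.sort(m, dimension=0)
# by_first -> [[0, 1, 2], [3, 5, 4]]

# A sequence operand yields a tuple, even when it holds a single array.
as_tuple = lax.sort((m,))
print(type(as_tuple), len(as_tuple))
```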