jax.lax.sort

jax.lax.sort(operand, dimension=-1, is_stable=True, num_keys=1)

Wraps XLA’s Sort operator, sorting one or more arrays along a single dimension.

Parameters
  • operand (Union[Any, Sequence[Any]]) – array or sequence of arrays to sort.

  • dimension (int) – integer dimension along which to sort. Default: -1.

  • is_stable (bool) – boolean specifying whether to use a stable sort. Default: True.

  • num_keys (int) – number of operands to treat as sort keys. Default: 1. For num_keys > 1, the sort order will be determined lexicographically using the first num_keys arrays, with the first key being primary. The remaining operands will be returned with the same permutation.
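The lexicographic behaviour of num_keys can be illustrated with a minimal sketch. The array names primary, secondary, and payload are hypothetical and chosen only for this example; the first two arrays act as sort keys and the third is carried along with the same permutation.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical inputs: two key arrays and one payload array of equal shape.
primary = jnp.array([2, 1, 2, 1])
secondary = jnp.array([0, 3, 1, 2])
payload = jnp.array([10, 20, 30, 40])

# With num_keys=2, ties in `primary` are broken by `secondary`;
# `payload` is not compared, only permuted the same way.
sorted_primary, sorted_secondary, sorted_payload = lax.sort(
    (primary, secondary, payload), num_keys=2
)

print(sorted_primary)    # [1 1 2 2]
print(sorted_secondary)  # [2 3 0 1]
print(sorted_payload)    # [40 20 10 30]
```

With num_keys=1 the same call would order the rows by primary alone, and is_stable=True would keep tied rows in their original relative order.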

Returns

The sorted array, or a tuple of sorted arrays when operand is a sequence.

Return type

operand
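As a rough usage sketch, the following shows a single-array sort along different dimensions and a key/value sort with the default num_keys=1. The input values and variable names are made up for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical 2-D input.
x = jnp.array([[3, 1, 2],
               [0, 7, 1]])

# Default dimension=-1 sorts within each row.
rows_sorted = lax.sort(x)
print(rows_sorted)  # [[1 2 3]
                    #  [0 1 7]]

# dimension=0 sorts within each column.
cols_sorted = lax.sort(x, dimension=0)
print(cols_sorted)  # [[0 1 1]
                    #  [3 7 2]]

# Passing a sequence returns one sorted array per input. With num_keys=1,
# only `keys` is compared; `values` follows the same permutation, and the
# stable sort keeps the two entries with key 2 in their original order.
keys = jnp.array([2, 1, 2])
values = jnp.array([0, 1, 2])
sorted_keys, sorted_values = lax.sort((keys, values), num_keys=1)
print(sorted_keys)    # [1 2 2]
print(sorted_values)  # [1 0 2]
```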