jax.lax.sort_key_val

jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)[source]

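Sorts keys along dimension and applies the same permutation to values. keys and values must have the same shape. The result is a pair (sorted keys, correspondingly permuted values).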
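A minimal sketch of the basic behavior on 1-D inputs: the keys are sorted in ascending order and each value moves together with its key. The array contents are arbitrary illustrations.

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([3, 1, 2])
values = jnp.array([30.0, 10.0, 20.0])

# Sort along the default dimension (-1); values follow their keys.
sorted_keys, sorted_values = lax.sort_key_val(keys, values)

print(sorted_keys.tolist())    # [1, 2, 3]
print(sorted_values.tolist())  # [10.0, 20.0, 30.0]
```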

Parameters
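  • keys (Any) – array of keys to sort.

  • values (Any) – array with the same shape as keys; it is reordered with the same permutation applied to keys.

  • dimension (int) – dimension along which to sort. Defaults to -1, the last dimension.

  • is_stable (bool) – if True, elements with equal keys keep their original relative order. If False, that order is not guaranteed.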
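The sketch below illustrates the dimension and is_stable parameters. The first call sorts a 2-D array down its columns (dimension=0), so each column is sorted independently. The second call uses duplicate keys with values holding original positions, which makes the stable ordering of ties visible. The inputs are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Sort each column independently by passing dimension=0.
keys_2d = jnp.array([[3, 1],
                     [1, 2]])
values_2d = jnp.array([[0, 1],
                       [2, 3]])
k0, v0 = lax.sort_key_val(keys_2d, values_2d, dimension=0)
print(k0.tolist())  # [[1, 1], [3, 2]]
print(v0.tolist())  # [[2, 1], [0, 3]]

# With is_stable=True, equal keys keep their original relative order,
# so the positions stored in `values` stay increasing within each tie.
dup_keys = jnp.array([2, 1, 2, 1])
positions = jnp.array([0, 1, 2, 3])
ks, vs = lax.sort_key_val(dup_keys, positions, is_stable=True)
print(ks.tolist())  # [1, 1, 2, 2]
print(vs.tolist())  # [1, 3, 0, 2]
```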

Return type

Tuple[Any, Any]
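Because the second element of the returned tuple is values after the sorting permutation, passing a range of indices as values yields that permutation itself, similar to an argsort. This is a sketch of one possible usage pattern, not a statement about how jax.numpy.argsort is implemented.

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([0.5, -1.0, 2.0, 0.0])

# Use the indices 0..n-1 as values; after sorting they record where
# each sorted key came from in the original array.
sorted_keys, perm = lax.sort_key_val(keys, jnp.arange(keys.shape[0]))

print(sorted_keys.tolist())  # [-1.0, 0.0, 0.5, 2.0]
print(perm.tolist())         # [1, 3, 0, 2]
```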