jax.lax.sort_key_val

jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)

Sorts keys along dimension and applies the same permutation to values.
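A minimal sketch of the basic behavior on 1-D inputs. The array contents are illustrative: each value stays paired with its key after the sort.

```python
import jax.numpy as jnp
from jax import lax

# Keys to sort, and values that should follow their keys.
keys = jnp.array([3, 1, 2])
values = jnp.array([30.0, 10.0, 20.0])

# Sort along the default dimension (-1) and permute values identically.
sorted_keys, sorted_values = lax.sort_key_val(keys, values)

print(sorted_keys)    # [1 2 3]
print(sorted_values)  # [10. 20. 30.]
```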

Parameters:
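  • keys (Array) – the array to sort.

  • values (jax.typing.ArrayLike) – an array with the same shape as keys; its elements are rearranged by the same permutation that sorts keys.

  • dimension (int) – the dimension along which to sort. Defaults to -1 (the last dimension).

  • is_stable (bool) – whether to use a stable sort, which preserves the original relative order of equal keys. Defaults to True.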

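Returns:

A pair (sorted_keys, sorted_values): keys sorted along dimension, and values rearranged by the same permutation.

Return type:

tuple[Array, Array]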
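A minimal sketch of the dimension and is_stable parameters, using small illustrative arrays. With dimension=0, each column of a 2-D array is sorted independently; with is_stable=True, values whose keys tie keep their original order.

```python
import jax.numpy as jnp
from jax import lax

# 2-D keys of shape (3, 2); values are the keys scaled by 10 so the
# pairing is easy to check.
keys_2d = jnp.array([[3, 1],
                     [1, 2],
                     [2, 0]])
values_2d = keys_2d * 10

# Sort down each column (dimension 0) instead of along each row.
col_keys, col_vals = lax.sort_key_val(keys_2d, values_2d, dimension=0)
print(col_keys)  # [[1 0] [2 1] [3 2]]
print(col_vals)  # [[10  0] [20 10] [30 20]]

# Duplicate keys: a stable sort keeps tied values in their input order.
tie_keys_in = jnp.array([2, 1, 2, 1])
tie_vals_in = jnp.array([0, 1, 2, 3])
tie_keys, tie_vals = lax.sort_key_val(tie_keys_in, tie_vals_in, is_stable=True)
print(tie_keys)  # [1 1 2 2]
print(tie_vals)  # [1 3 0 2]
```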