jax.lax.sort_key_val

jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)

Sorts keys along dimension and applies the same permutation to values.

Parameters:
    keys (Array)
    values (ArrayLike)
    dimension (int)
    is_stable (bool)

Return type:
    tuple[Array, Array]
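
The following is a minimal usage sketch, not part of the original reference entry. It assumes a working JAX installation; the array contents and the variable names (`keys`, `values`, `tied_keys`, `tied_values`) are illustrative only. It shows the returned pair of sorted keys and correspondingly permuted values, and how `is_stable=True` keeps values with equal keys in their original relative order.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative inputs (assumed for this example): values[i] is paired with keys[i].
keys = jnp.array([3, 1, 2])
values = jnp.array([30, 10, 20])

# Sort keys along the last dimension (dimension=-1 by default) and
# apply the same permutation to values.
sorted_keys, sorted_values = lax.sort_key_val(keys, values)

# With tied keys, a stable sort preserves the original order of the
# values that share a key.
tied_keys = jnp.array([1, 0, 1, 0])
positions = jnp.arange(4)  # 0, 1, 2, 3: records each element's original index
tied_sorted_keys, tied_values = lax.sort_key_val(
    tied_keys, positions, is_stable=True
)

print(sorted_keys, sorted_values)
print(tied_sorted_keys, tied_values)
```

Passing an index array such as `positions` as `values` is one way to recover the sorting permutation itself, similar in effect to an argsort over `keys`.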