jax.lax.top_k


jax.lax.top_k(operand, k)

Returns the top k values and their indices along the last axis of operand.
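
A short sketch of what "along the last axis" means in practice. The 2-D input below is an illustrative assumption, not part of the original examples. Leading axes behave like batch dimensions, so each row is ranked independently and only the last axis is replaced by one of size k:

```python
import jax
import jax.numpy as jnp

# Two rows; top_k ranks each row separately along the last axis.
x = jnp.array([[1., 5., 3., 2.],
               [7., 0., 4., 6.]])
values, indices = jax.lax.top_k(x, 2)

# Leading axis (size 2) is kept; the last axis now has size k = 2.
print(values.shape)      # (2, 2)
print(values.tolist())   # [[5.0, 3.0], [7.0, 6.0]]
print(indices.tolist())  # [[1, 2], [0, 3]]
```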

Parameters:
  • operand (jax.typing.ArrayLike) – N-dimensional array of non-complex type.

  • k (int) – integer specifying the number of top entries to return; it must be no larger than the size of the last axis of operand.
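
Because k determines the shape of the result, it has to be a concrete Python integer rather than a traced value. A minimal sketch of using top_k inside jax.jit under that assumption, marking k as a static argument (the function name largest is hypothetical):

```python
from functools import partial

import jax
import jax.numpy as jnp

# k fixes the output shape, so it is passed as a static (compile-time) argument.
@partial(jax.jit, static_argnums=1)
def largest(x, k):
    values, _ = jax.lax.top_k(x, k)
    return values

x = jnp.array([9., 3., 6., 4., 10.])
result = largest(x, 2)
print(result.tolist())  # [10.0, 9.0]
```

Each distinct value of k triggers a separate compilation, which is the usual trade-off for static arguments.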

Returns:

A tuple (values, indices) where

  • values is an array containing the top k values along the last axis, sorted in descending order. It has the same shape as operand except that the last axis has size k.

  • indices is an array of the same shape as values, containing the positions along the last axis of operand from which each entry of values was taken.
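
One way to check the relationship between the two outputs, sketched here with jnp.take_along_axis on an illustrative 2-D input (the input is an assumption, not from the original page): gathering operand at indices along the last axis reproduces values, and values is non-increasing along that axis.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1., 5., 3., 2.],
               [7., 0., 4., 6.]])
values, indices = jax.lax.top_k(x, 2)

# Gather the original entries at the returned positions along the last axis.
gathered = jnp.take_along_axis(x, indices, axis=-1)
same = bool(jnp.array_equal(gathered, values))

# Each row of values is sorted from largest to smallest.
descending = bool(jnp.all(values[..., :-1] >= values[..., 1:]))
print(same, descending)  # True True
```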

Return type:

tuple[Array, Array]

Examples

Find the largest three values, and their indices, within an array:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10.,  9.,  6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)
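
Since top_k only ranks along the last axis, selecting along a different axis requires moving that axis to the end first. The helper below, top_k_along_axis, is a hypothetical sketch built from jnp.moveaxis, not part of the jax.lax API:

```python
import jax
import jax.numpy as jnp

def top_k_along_axis(x, k, axis):
    """Hypothetical helper: top-k along an arbitrary axis of x."""
    # Move the target axis to the end so jax.lax.top_k can rank it.
    moved = jnp.moveaxis(x, axis, -1)
    values, indices = jax.lax.top_k(moved, k)
    # Move the size-k result axis back to the original position.
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)

x = jnp.array([[1., 5., 3.],
               [7., 0., 4.],
               [2., 6., 8.]])
# Two largest entries in each column (axis 0).
values, indices = top_k_along_axis(x, 2, axis=0)
print(values.tolist())   # [[7.0, 6.0, 8.0], [2.0, 5.0, 4.0]]
print(indices.tolist())  # [[1, 2, 2], [2, 0, 1]]
```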