jax.lax.top_k



jax.lax.top_k(operand, k)

Returns the top k values and their indices along the last axis of operand.
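A minimal usage sketch, assuming a small 2-D float array chosen purely for illustration. Each row is reduced independently along the last axis, and the selected entries come back largest first.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1.0, 5.0, 3.0, 4.0],
               [9.0, 2.0, 8.0, 7.0]])

# Select the 2 largest entries of each row (top_k works along the last axis).
values, indices = lax.top_k(x, 2)

# expected values:  [[5., 4.], [9., 8.]]
# expected indices: [[1, 3], [0, 2]]
```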

Parameters:
  • operand (jax.typing.ArrayLike) – N-dimensional array of non-complex type.

  • k (int) – number of top entries to return along the last axis.
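Because k determines the shape of the outputs, it has to be a concrete Python int. A sketch of one way to handle this under jax.jit, assuming a hypothetical wrapper named largest_k, is to mark k as a static argument:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical wrapper: k is declared static so that jit treats it as a
# compile-time constant (each distinct k triggers a separate compilation).
@partial(jax.jit, static_argnames="k")
def largest_k(x, k):
    return lax.top_k(x, k)


vals, idx = largest_k(jnp.arange(10.0), k=3)
# expected: vals == [9., 8., 7.], idx == [9, 8, 7]
```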

Returns:

A tuple (values, indices), where:

  • values – array containing the top k values along the last axis.

  • indices – array containing the indices corresponding to values.

Return type:

tuple[Array, Array]
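The sketch below, assuming a 3-D array built with jnp.arange for illustration, shows two properties of the returned pair: leading (batch) dimensions are preserved while only the last axis shrinks to k, and gathering operand at indices along the last axis with jnp.take_along_axis reproduces values.

```python
import jax.numpy as jnp
from jax import lax

# A batch of 2 x 3 rows, each of length 5.
x = jnp.arange(30.0).reshape(2, 3, 5)
values, indices = lax.top_k(x, 2)

# Leading dimensions are kept; only the last axis shrinks to k.
print(values.shape, indices.shape)  # (2, 3, 2) (2, 3, 2)

# Gathering the operand at the returned indices along the last axis
# reproduces the returned values.
gathered = jnp.take_along_axis(x, indices, axis=-1)
print(bool(jnp.array_equal(gathered, values)))  # True
```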

See also:
  • jax.lax.approx_max_k()
  • jax.lax.approx_min_k()
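As its name indicates, jax.lax.approx_min_k is an approximate operation. When an exact result is needed, one common workaround is to apply top_k to the negated operand. The helper bottom_k below is hypothetical and not part of jax.lax. It assumes a floating-point input, since negation does not reverse the order of unsigned integers and overflows for the most negative signed integer.

```python
import jax.numpy as jnp
from jax import lax


def bottom_k(operand, k):
    """Hypothetical helper: exact k smallest entries along the last axis.

    Assumes a floating-point dtype, where negation reverses the ordering.
    """
    neg_values, indices = lax.top_k(-operand, k)
    return -neg_values, indices


x = jnp.array([3.0, -1.0, 4.0, 1.5, -2.0])
vals, idx = bottom_k(x, 2)
# expected: vals == [-2., -1.], idx == [4, 1]
```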