jax.lax.top_k

jax.lax.top_k(operand, k)

Returns the top k values and their indices along the last axis of operand. The values are sorted in descending order.
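A minimal sketch of typical usage. The input arrays are illustrative values chosen for this example. On a 1-D array, top_k picks the k largest entries. On a 2-D array, it works independently on each row, because the reduction always runs over the last axis.

```python
import jax.numpy as jnp
from jax import lax

# 1-D input: the three largest entries and where they occur.
x = jnp.array([9.0, 3.0, 6.0, 4.0, 10.0])
values, indices = lax.top_k(x, 3)
# values.tolist()  -> [10.0, 9.0, 6.0]
# indices.tolist() -> [4, 0, 2]

# 2-D input: each row (last axis) is handled independently.
m = jnp.array([[1.0, 5.0, 3.0],
               [7.0, 2.0, 8.0]])
row_vals, row_idx = lax.top_k(m, 2)
# row_vals.tolist() -> [[5.0, 3.0], [8.0, 7.0]]
# row_idx.tolist()  -> [[1, 2], [2, 0]]
```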

Parameters
  • operand (Any) – N-dimensional array of non-complex type.

  • k (int) – number of top entries to return along the last axis.
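The sketch below checks how the two parameters interact, under stated assumptions. It uses random data from jax.random and a full descending sort as the reference; neither is part of the top_k API. The result keeps every leading dimension of operand and replaces the last one with k. The returned indices gather back to exactly the returned values.

```python
import jax
import jax.numpy as jnp
from jax import lax

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 10))  # 4 rows of 10 random floats
k = 3

values, indices = lax.top_k(x, k)

# Output shape is operand.shape[:-1] + (k,).
assert values.shape == (4, k) and indices.shape == (4, k)

# Reference: sort each row ascending, flip to descending, keep the first k.
expected = jnp.flip(jnp.sort(x, axis=-1), axis=-1)[:, :k]

# Gathering x at the returned indices reproduces the returned values.
gathered = jnp.take_along_axis(x, indices, axis=-1)
```

A full sort costs more than selecting only k entries, so the sort here serves only as a correctness reference.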

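Returns

A tuple (values, indices). values holds the k largest entries along the last axis of operand, in descending order. indices holds the integer positions of those entries within that axis.

Return type

Tuple[Any, Any]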
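The signature shown here always reduces over the last axis. To select along a different axis, one option is to move that axis to the end, call top_k, and then move the result axis back. The helper name top_k_along_axis below is hypothetical and not part of JAX. It is a sketch built only from jnp.moveaxis and lax.top_k.

```python
import jax.numpy as jnp
from jax import lax

def top_k_along_axis(operand, k, axis):
    """Hypothetical helper: top-k over an arbitrary axis via lax.top_k.

    Moves `axis` to the last position, applies lax.top_k, then moves the
    size-k result axis back to where `axis` was.
    """
    moved = jnp.moveaxis(operand, axis, -1)
    values, indices = lax.top_k(moved, k)
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)

m = jnp.array([[1.0, 5.0, 3.0],
               [7.0, 2.0, 8.0],
               [4.0, 6.0, 0.0]])

# Top 2 entries of each column (axis 0) instead of each row.
col_vals, col_idx = top_k_along_axis(m, 2, axis=0)
# col_vals.tolist() -> [[7.0, 6.0, 8.0], [4.0, 5.0, 3.0]]
# col_idx.tolist()  -> [[1, 2, 1], [2, 0, 0]]
```

The returned indices refer to positions along the chosen axis, which matches how the last-axis indices from top_k are interpreted.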