# jax.lax.top_k

`jax.lax.top_k(operand, k)`

Returns the top `k` values of `operand`, and their indices, along the last axis of `operand`.
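
A minimal sketch of the last-axis behavior on a small 2-D array (the array contents are made up for illustration). Each row is reduced independently, and the values come back sorted from largest to smallest, with `indices` giving their positions within the row.

```python
import jax.numpy as jnp
from jax import lax

# A 2x4 input: top_k treats each row (the last axis) independently.
x = jnp.array([[3.0, 1.0, 4.0, 1.5],
               [9.0, 2.0, 6.0, 5.0]])

# Keep the 2 largest entries of each row.
values, indices = lax.top_k(x, 2)

# values  -> [[4., 3.], [9., 6.]]  (largest first, shape (2, 2))
# indices -> [[2, 0], [0, 2]]      (positions along the last axis)
```
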

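Parameters

- **operand** (`ArrayLike`): N-dimensional array of non-complex type.
- **k** (`int`): number of top entries to return along the last axis.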
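Return type

`tuple[Array, Array]`: a pair `(values, indices)`. `values` has the shape of `operand` with the last axis reduced to size `k`, sorted in descending order; `indices` has the same shape and holds the position of each value along the last axis of `operand`.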
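
Because `top_k` only works on the last axis and only selects the largest entries, other cases have to be built around it. The sketch below shows two common patterns, under stated assumptions: `top_k_along_axis` and `bottom_k` are hypothetical helpers, not part of `jax.lax`. The first moves the target axis to the end with `jnp.moveaxis` and moves it back afterwards; the second negates the input to pick the `k` smallest entries, which is only safe for floating-point (or non-overflowing signed) inputs.

```python
import jax.numpy as jnp
from jax import lax


def top_k_along_axis(x, k, axis):
    """Hypothetical helper: top-k along an arbitrary axis.

    lax.top_k reduces the last axis, so move `axis` to the end,
    run top_k, then move the size-k axis back into place.
    """
    moved = jnp.moveaxis(x, axis, -1)
    values, indices = lax.top_k(moved, k)
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)


def bottom_k(x, k):
    """Hypothetical helper: the k smallest entries along the last axis.

    Negating the input turns "smallest" into "largest". This assumes a
    float input; unsigned ints or the minimum signed int would wrap.
    """
    neg_values, indices = lax.top_k(-x, k)
    return -neg_values, indices


x = jnp.array([[3.0, 1.0, 4.0],
               [9.0, 2.0, 6.0]])

# Largest entry of each column (axis 0 instead of the last axis).
col_vals, col_idx = top_k_along_axis(x, 1, axis=0)
# col_vals -> [[9., 2., 6.]], col_idx -> [[1, 1, 1]]

# Two smallest entries of each row, smallest first.
small_vals, small_idx = bottom_k(x, 2)
# small_vals -> [[1., 3.], [2., 6.]], small_idx -> [[1, 0], [1, 2]]
```

Moving the axis rather than transposing by hand keeps the helper valid for any rank of input.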