jax.lax.top_k

jax.lax.top_k(operand, k)

Returns the top k values (the k largest entries) and their indices along the last axis of operand.
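A minimal usage sketch, assuming a recent JAX release. On a 1-D array top_k picks the k largest entries. On a 2-D array it works row by row, because the last axis is the column axis. The helper top_k_along_axis is hypothetical and not part of jax.lax. It shows one way to select along a different axis by moving that axis to the end first.

```python
import jax.numpy as jnp
from jax import lax

# 1-D input: the k largest entries, in descending order, plus their positions.
x = jnp.array([3.0, 1.0, 4.0, 1.5, 9.0, 2.0])
values, indices = lax.top_k(x, 3)
# values -> [9., 4., 3.], indices -> [4, 2, 0]

# 2-D input: selection happens independently for each row (the last axis).
m = jnp.array([[1, 5, 3],
               [7, 2, 6]])
row_vals, row_idx = lax.top_k(m, 2)
# row_vals -> [[5, 3], [7, 6]], row_idx -> [[1, 2], [0, 2]]


def top_k_along_axis(a, k, axis):
    """Hypothetical helper (not in jax.lax): top-k along any axis via moveaxis."""
    moved = jnp.moveaxis(a, axis, -1)   # bring `axis` to the last position
    vals, idx = lax.top_k(moved, k)     # top_k always selects along the last axis
    # Move the selected axis back to where it was in the input.
    return jnp.moveaxis(vals, -1, axis), jnp.moveaxis(idx, -1, axis)


# Largest entry of each column of m (selection along axis 0).
col_vals, col_idx = top_k_along_axis(m, 1, axis=0)
# col_vals -> [[7, 5, 6]], col_idx -> [[1, 0, 1]]
```

Moving the axis rather than transposing keeps the helper valid for inputs of any rank.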

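Parameters

operand – N-dimensional array of non-complex type; the top entries are taken along its last axis.
k – integer specifying the number of top entries to return.

Returns

A tuple (values, indices), where values contains the k largest entries of operand along its last axis, sorted in descending order, and indices contains the positions of those entries along that axis.
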
Return type

Tuple[Array, Array]
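The second element of the returned tuple holds positions along the last axis, so it can be passed to jnp.take_along_axis to gather matching entries, either from operand itself or from another array of the same shape. This is a sketch assuming jax.numpy.take_along_axis. The scores and labels arrays are illustrative data, not part of the API.

```python
import jax.numpy as jnp
from jax import lax

scores = jnp.array([[0.1, 0.7, 0.2],
                    [0.5, 0.3, 0.9]])
labels = jnp.array([[10, 11, 12],
                    [20, 21, 22]])  # illustrative per-entry labels

# Unpack the (values, indices) tuple returned by top_k.
values, indices = lax.top_k(scores, 2)

# indices refer to the last axis, so gathering along axis=-1 lines them up.
same_values = jnp.take_along_axis(scores, indices, axis=-1)  # equals `values`
top_labels = jnp.take_along_axis(labels, indices, axis=-1)
# top_labels -> [[11, 12], [22, 20]]
```

The same gather works for any array whose last axis is aligned with operand's last axis.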