jax.lax.top_k
- jax.lax.top_k(operand, k)
Returns the top k values and their indices along the last axis of operand.
- Parameters:
operand (jax.typing.ArrayLike) – N-dimensional array of non-complex type.
k (int) – integer specifying the number of top entries.
- Returns:
A pair (values, indices). values: array containing the top k values along the last axis. indices: array containing the indices corresponding to values.
- Return type:
tuple of two arrays, (values, indices)
See also:
- jax.lax.approx_max_k()
- jax.lax.approx_min_k()
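A minimal usage sketch may help make the shapes of the two results concrete. The input array `x` and the choice `k = 2` are illustrative assumptions, not part of the reference text. The example operates on a 2-D array to show that the selection runs independently along the last axis of each row, and it uses `jnp.take_along_axis` to check that the returned indices point back at the returned values.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input (an assumption for this sketch): two rows of five values.
x = jnp.array([[9.0, 3.0, 6.0, 4.0, 10.0],
               [1.0, 7.0, 2.0, 8.0, 5.0]])

# Select the 2 largest entries along the last axis of each row.
values, indices = lax.top_k(x, 2)

# Both results keep the leading dimensions and have size k on the last axis.
print(values.shape, indices.shape)

# The indices refer to positions along the last axis of x, so gathering
# with them should reproduce the returned values.
gathered = jnp.take_along_axis(x, indices, axis=-1)
print(bool(jnp.array_equal(gathered, values)))
```

For this input, each row of `values` should hold that row's two largest entries, largest first, and `indices` should hold their positions within the row.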