jax.lax.argmax

jax.lax.argmax(operand, axis, index_dtype)

Computes the index of the maximum element along axis.

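Parameters

operand (ArrayLike): the input array.
axis (int): the dimension along which to find the maximum.
index_dtype (DTypeLike): the integer dtype of the returned indices.
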
Return type

Array
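A minimal usage sketch, assuming a standard JAX installation. The array `x` and the variable names are illustrative only. It finds the position of the maximum within each row (`axis=1`) and within each column (`axis=0`) of a small 2-D array, then cross-checks the row result against `jax.numpy.argmax`.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 2-D input (values chosen so no row or column has a tie).
x = jnp.array([[1.0, 5.0, 3.0],
               [7.0, 2.0, 4.0]])

# Index of the maximum within each row: reduces axis 1, result shape (2,).
row_idx = lax.argmax(x, axis=1, index_dtype=jnp.int32)

# Index of the maximum within each column: reduces axis 0, result shape (3,).
col_idx = lax.argmax(x, axis=0, index_dtype=jnp.int32)

print(row_idx)  # [1 0]
print(col_idx)  # [1 0 1]

# jax.numpy.argmax gives the same indices for this input.
print(bool((row_idx == jnp.argmax(x, axis=1)).all()))  # True
```

The reduced dimension is removed from the output shape, and `index_dtype` fixes the integer type of the result independently of the dtype of `operand`.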