jax.lax.argmin

jax.lax.argmin(operand, axis, index_dtype)

Computes the index of the minimum element of operand along dimension axis. The result has the shape of operand with axis removed.

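Parameters
operand (ArrayLike): the input array.
axis (int): the dimension along which to find the minimum.
index_dtype (DTypeLike): the integer dtype of the returned indices.

Return type
Array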
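A minimal usage sketch, assuming a standard JAX install. It finds the minimum's index down each column (axis=0) and across each row (axis=1) of a small 2-D array, and compares the first result with jax.numpy.argmin. Only non-negative axes are used here, and the array values are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[3.0, 1.0, 2.0],
               [0.5, 4.0, -1.0]])

# Index of the minimum down each column (reduces over axis 0).
cols = lax.argmin(x, axis=0, index_dtype=jnp.int32)

# Index of the minimum across each row (reduces over axis 1).
rows = lax.argmin(x, axis=1, index_dtype=jnp.int32)

# The higher-level NumPy-style API gives the same indices here.
np_cols = jnp.argmin(x, axis=0)

print(cols, rows, np_cols)
```

Unlike jnp.argmin, the lax version requires both axis and index_dtype to be given explicitly, which makes the output dtype independent of global settings such as 64-bit mode.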