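jax.lax.argmin

jax.lax.argmin(operand, axis, index_dtype)

Computes the index of the minimum element of operand along axis.

Parameters:
- operand (jax.typing.ArrayLike): the input array.
- axis (int): the dimension of operand to reduce over.
- index_dtype (jax.typing.DTypeLike): the integer dtype of the returned indices.

Return type:
Array: the indices of the minimum elements, with dtype index_dtype and with the axis dimension removed from operand's shape.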
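A minimal usage sketch, assuming a standard JAX install. The array values are arbitrary and chosen to have no ties. It shows that the reduced axis is dropped from the result and that the indices take the requested index_dtype. jnp.int32 is used because 64-bit integer types are downcast unless JAX's x64 mode is enabled.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 2x3 input with a unique minimum in every row and column.
x = jnp.array([[3.0, 1.0, 2.0],
               [0.5, 4.0, 0.7]])

# Reduce over axis 1: one index per row, so the result has shape (2,).
row_idx = lax.argmin(x, axis=1, index_dtype=jnp.int32)

# Reduce over axis 0: one index per column, so the result has shape (3,).
col_idx = lax.argmin(x, axis=0, index_dtype=jnp.int32)

print(row_idx)  # → [1 0]
print(col_idx)  # → [1 0 1]
```

For everyday code, jax.numpy.argmin offers a NumPy-style interface (optional axis, default index dtype) on top of this lower-level primitive. For the inputs above it gives the same indices.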