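jax.lax.argmin

jax.lax.argmin(operand, axis, index_dtype)

Computes the index of the minimum element of operand along axis.

Parameters:
  operand (Union[Array, ndarray, bool_, number, bool, int, float, complex]): the input array.
  axis (int): the dimension along which to search for the minimum; this dimension is removed from the result's shape.
  index_dtype (Union[str, type[Any], dtype, SupportsDType]): the integer dtype of the returned indices.

Return type: Array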
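A minimal usage sketch, assuming a standard JAX install with default 32-bit mode. The input values and the choice of jnp.int32 for index_dtype are illustrative only. Note that, unlike jax.numpy.argmin, all three arguments are required here, including the index dtype.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 2x3 input (values chosen arbitrarily, no ties).
x = jnp.array([[3.0, 1.0, 2.0],
               [0.5, 4.0, -1.0]])

# Reduce over axis 1: one index per row, so the result has shape (2,).
idx_rows = lax.argmin(x, axis=1, index_dtype=jnp.int32)

# Reduce over axis 0: one index per column, so the result has shape (3,).
idx_cols = lax.argmin(x, axis=0, index_dtype=jnp.int32)

print(idx_rows)  # [1 2]
print(idx_cols)  # [1 0 1]
```

In both calls the reduced axis disappears from the output shape, and the returned indices have the dtype passed as index_dtype.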