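jax.lax.argmax

jax.lax.argmax(operand, axis, index_dtype)

Computes the index of the maximum element along axis.

Parameters
  operand (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – the input array to search.
  axis (int) – the dimension along which to find the maximum. This dimension is removed from the output shape.
  index_dtype (Union[Any, str, dtype, SupportsDType]) – the integer dtype used for the returned indices.

Return type
  Array: the indices of the maximum elements. Its shape is the shape of operand with axis removed, and its dtype is index_dtype.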
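A minimal usage sketch, assuming a standard JAX installation. The array values are illustrative and chosen without ties. It shows that the reduced axis disappears from the result shape and that the indices come back in the requested index_dtype.

```python
import jax.numpy as jnp
from jax import lax

# A 2x3 input; no row or column contains a tied maximum.
x = jnp.array([[1.0, 5.0, 3.0],
               [7.0, 2.0, 4.0]])

# Search along axis 1 (within each row): the result has shape (2,).
row_idx = lax.argmax(x, axis=1, index_dtype=jnp.int32)

# Search along axis 0 (within each column): the result has shape (3,).
col_idx = lax.argmax(x, axis=0, index_dtype=jnp.int32)

print(row_idx)  # [1 0]
print(col_idx)  # [1 0 1]
```

Unlike jax.numpy.argmax, which follows the NumPy convention, this lax-level function requires both axis and index_dtype to be given explicitly.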