jax.lax.argmax

jax.lax.argmax(operand, axis, index_dtype)

Computes the index of the maximum element along axis.
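Here is a minimal sketch of what "along axis" means, using a small 2x3 matrix. The values are arbitrary and were chosen only for this example. Reducing over axis 0 compares the entries within each column. Reducing over axis 1 compares the entries within each row.

```python
import jax.numpy as jnp
from jax import lax

# Arbitrary example data: 2 rows, 3 columns.
x = jnp.array([[1.0, 7.0, 3.0],
               [9.0, 2.0, 5.0]])

# axis=0: for each column, the row index of its largest entry.
col_idx = lax.argmax(x, axis=0, index_dtype=jnp.int32)

# axis=1: for each row, the column index of its largest entry.
row_idx = lax.argmax(x, axis=1, index_dtype=jnp.int32)

print(col_idx)  # [1 0 1]
print(row_idx)  # [1 0]
```

For these inputs the results agree with `jnp.argmax(x, axis=...)`. The `jnp` version is the NumPy-style wrapper and picks a default index dtype for you, whereas `lax.argmax` requires `index_dtype` explicitly.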

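Parameters
  • operand (Any) – the input array.

  • axis (int) – the dimension along which to search for the maximum; this dimension is removed from the result.

  • index_dtype (Any) – the integer dtype of the returned indices, e.g. jnp.int32.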

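Returns

A single array of indices, of dtype index_dtype, whose shape is the shape of operand with axis removed.

Return type

Any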
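The following sketch checks the shape and dtype of the result on a random 3-D array. The shapes and the "batch x positions x classes" framing are illustrative assumptions. The result is one array, not a tuple. Its dtype comes from index_dtype, and the reduced axis is dropped from its shape.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative input: 4 examples x 3 positions x 5 class scores.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 5))

# Reduce over the last axis; the result keeps the first two dimensions.
preds = lax.argmax(logits, axis=2, index_dtype=jnp.int32)

print(preds.shape)  # (4, 3)
print(preds.dtype)  # int32
```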