jax.numpy.argmax

jax.numpy.argmax(a, axis=None, out=None)

Returns the indices of the maximum values along an axis.

LAX-backend implementation of numpy.argmax().

Original docstring below.

Parameters
  • a (array_like) – Input array.

  • axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.
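The effect of axis is easiest to see on a small 2-D array. The following is a minimal sketch, assuming JAX is installed; the array values and variable names are illustrative only.

```python
import jax.numpy as jnp

a = jnp.array([[1, 9, 3],
               [7, 2, 8]])

# axis=None (the default): index into the flattened array [1, 9, 3, 7, 2, 8]
flat_index = jnp.argmax(a)

# axis=0: compare down each column, giving one row index per column
per_column = jnp.argmax(a, axis=0)

# axis=1: compare across each row, giving one column index per row
per_row = jnp.argmax(a, axis=1)

print(flat_index)   # 1
print(per_column)   # [1 0 1]
print(per_row)      # [1 2]
```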

Returns

index_array – Array of indices into the array. Its shape is a.shape with the dimension along axis removed.

Return type

ndarray of ints
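The shape and dtype of the result can be checked directly. The sketch below assumes a 3-D input of shape (2, 3, 4) chosen purely for illustration, and uses jnp.take_along_axis as one possible way to turn the returned indices back into the maximum values.

```python
import jax.numpy as jnp

a = jnp.arange(24).reshape(2, 3, 4)

# Reducing along axis=1 removes that dimension: (2, 3, 4) -> (2, 4)
idx = jnp.argmax(a, axis=1)
print(idx.shape)  # (2, 4)

# The result holds integer indices
print(jnp.issubdtype(idx.dtype, jnp.integer))  # True

# Re-insert the reduced axis so the indices can be used to gather values,
# then drop it again to match the shape of a.max(axis=1)
max_vals = jnp.take_along_axis(a, jnp.expand_dims(idx, axis=1), axis=1).squeeze(axis=1)
print(bool(jnp.array_equal(max_vals, a.max(axis=1))))  # True
```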