jax.numpy.argmax


jax.numpy.argmax(a, axis=None, out=None, keepdims=None)

Returns the indices of the maximum values along an axis.

LAX-backend implementation of numpy.argmax().
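The sketch below shows the default behaviour, where axis is None and the array is treated as flattened. It uses a small made-up input with no tied maxima. It also shows one way to turn the flat index back into a multi-dimensional position with jnp.unravel_index.

```python
import jax.numpy as jnp

# Example input (assumed for illustration); the maximum, 9, appears once.
a = jnp.array([[1, 7, 3],
               [9, 2, 4]])

# With axis=None (the default), argmax works on the flattened array
# [1, 7, 3, 9, 2, 4], so the result is a single scalar index.
flat_idx = jnp.argmax(a)
print(flat_idx)  # 3

# Convert the flat index back to a (row, column) position in `a`.
row, col = jnp.unravel_index(flat_idx, a.shape)
print(row, col)  # 1 0
```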

Original docstring below.

Parameters:
  • a (array_like) – Input array.

  • axis (int, optional) – Axis along which to search for the maximum. By default, the index is into the flattened array; otherwise, it is along the specified axis.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.

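  • out (None) – Not supported by JAX and must be left as None. JAX arrays are immutable, so results cannot be written into a preallocated output array.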
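This is a minimal sketch of how axis and keepdims affect the result, using the same illustrative 2×3 input. Reducing along an axis removes that axis from the output shape. With keepdims=True, the axis is kept with size 1 instead.

```python
import jax.numpy as jnp

a = jnp.array([[1, 7, 3],
               [9, 2, 4]])

# axis=0: index of the maximum in each column; the row axis is removed.
cols = jnp.argmax(a, axis=0)                      # [1 0 1], shape (3,)

# axis=1: index of the maximum in each row; the column axis is removed.
rows = jnp.argmax(a, axis=1)                      # [1 0], shape (2,)

# keepdims=True keeps the reduced axis as a dimension of size one.
rows_kept = jnp.argmax(a, axis=1, keepdims=True)  # [[1] [0]], shape (2, 1)

print(cols.shape, rows.shape, rows_kept.shape)    # (3,) (2,) (2, 1)
```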

Returns:

index_array – Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, that dimension is kept with size 1 instead, so the result has the same number of dimensions as a.
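Because a result computed with keepdims=True has the same number of dimensions as a, it can be used directly as an index array. The sketch below, on the same illustrative input, gathers the row maxima with jnp.take_along_axis. It then checks that they match jnp.max and that the indices have an integer dtype.

```python
import jax.numpy as jnp

a = jnp.array([[1, 7, 3],
               [9, 2, 4]])

# Shape (2, 1): same number of dimensions as `a`, so it lines up along axis 1.
idx = jnp.argmax(a, axis=1, keepdims=True)

# Gather the value at each row's argmax position.
row_max = jnp.take_along_axis(a, idx, axis=1)  # [[7] [9]]

# The gathered values agree with a direct max reduction.
assert jnp.array_equal(row_max, jnp.max(a, axis=1, keepdims=True))

# The returned indices are integers.
assert jnp.issubdtype(idx.dtype, jnp.integer)
```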

Return type:

ndarray of ints