jax.numpy.max(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Return the maximum of an array or maximum along an axis.

LAX-backend implementation of numpy.max().
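The short sketch below (assuming JAX and NumPy are installed; the array `x` and the variable names are purely illustrative) shows the basic call. With no `axis`, the maximum is taken over the flattened input. With an integer `axis`, the reduction runs along that dimension only, and the result matches `numpy.max` on the same data.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[1, 5, 2],
               [7, 3, 4]])

# No axis: reduce over the flattened input.
overall = jnp.max(x)            # 7
# axis=0: maximum of each column.
col_max = jnp.max(x, axis=0)    # [7, 5, 4]
# axis=1: maximum of each row.
row_max = jnp.max(x, axis=1)    # [5, 7]

# The column result agrees with NumPy's implementation on the same data.
assert np.array_equal(np.asarray(col_max), np.max(np.asarray(x), axis=0))
```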

Original docstring below.

Parameters:

  • a (array_like) – Input data.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, then keepdims will not be passed through to the max method of sub-classes of ndarray; however, any non-default value will be. If the sub-class's method does not implement keepdims, any exceptions will be raised.

  • initial (scalar, optional) – The minimum value of an output element. Must be present to allow computation on an empty slice. See numpy.ufunc.reduce for details.

  • where (array_like of bool, optional) – Elements to compare for the maximum. See numpy.ufunc.reduce for details.

  • out (None) – Not supported by JAX; must be left as None.
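As a minimal sketch of `initial` and `where` (assuming the NumPy semantics described above carry over to JAX; the arrays here are illustrative): because the maximum has no identity element, a `where` mask is paired with an explicit `initial`, which also puts a lower bound on the result and gives empty inputs a defined maximum.

```python
import jax.numpy as jnp

x = jnp.array([3.0, -1.0, 8.0, 2.0])
mask = jnp.array([True, True, False, True])

# where= restricts which elements take part; max has no identity element,
# so initial= is supplied alongside the mask.
masked = jnp.max(x, where=mask, initial=-jnp.inf)   # 3.0 (8.0 is masked out)

# initial= is also the minimum value any output element can take.
floored = jnp.max(x, initial=10.0)                  # 10.0

# With initial=, an empty input still has a well-defined maximum.
empty_max = jnp.max(jnp.zeros((0,)), initial=0.0)   # 0.0
```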


Returns:

max – Maximum of a. If axis is None, the result is a scalar value. If axis is an int, the result is an array of dimension a.ndim - 1. If axis is a tuple, the result is an array of dimension a.ndim - len(axis).

Return type:

ndarray or scalar
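The shape rules for the result can be checked directly. In this sketch (the array `a` and the variable names are illustrative), note that in JAX the "scalar" produced by axis=None is a zero-dimensional array. The last lines show how keepdims=True keeps reduced axes as size-one dimensions so the result broadcasts against the input.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)   # a.ndim == 3

full = jnp.max(a)                 # shape (): axis=None reduces every axis
one = jnp.max(a, axis=1)          # shape (2, 4): a.ndim - 1 dimensions
two = jnp.max(a, axis=(0, 2))     # shape (3,): a.ndim - len(axis) dimensions

# keepdims=True leaves the reduced axis in place with size one, so the
# result broadcasts against the input, e.g. to subtract each row's maximum.
m = jnp.max(a, axis=-1, keepdims=True)   # shape (2, 3, 1)
shifted = a - m                          # shape (2, 3, 4); each row now peaks at 0
```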