# jax.numpy.max

jax.numpy.max(a, axis=None, dtype=None, out=None, keepdims=False)

Return the maximum of an array or maximum along an axis.

LAX-backend implementation of amax(). Original docstring below.

Parameters
• a (array_like) – Input data.

• axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.

• out (ndarray, optional) – Alternative output array in which to place the result. Must be of the same shape and buffer length as the expected output. See ufuncs-output-type for more details.

• keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
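
As an illustration of the axis and keepdims parameters described above, here is a minimal sketch using jax.numpy. The array values and variable names are chosen for illustration and are not part of the original docstring.

```python
import jax.numpy as jnp

a = jnp.arange(6).reshape((2, 3))  # [[0, 1, 2], [3, 4, 5]]

# axis=None (the default): reduce over the flattened input
max_all = jnp.max(a)

# axis=0: reduce over rows, giving one maximum per column
max_cols = jnp.max(a, axis=0)

# keepdims=True: the reduced axis stays in the result with size one
max_rows = jnp.max(a, axis=1, keepdims=True)

# Because max_rows has shape (2, 1), it broadcasts against a
shifted = a - max_rows
```

With keepdims=True the row maxima can be subtracted from the input directly, without reshaping.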

Returns

amax – Maximum of a. If axis is None, the result is a scalar value. If axis is an int, the result is an array of dimension a.ndim - 1. If axis is a tuple of ints, each listed axis is reduced.

Return type

ndarray or scalar
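
The effect of axis on the result's shape can be checked with a short sketch. The input shape (2, 3, 4) is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

x = jnp.zeros((2, 3, 4))

# axis=None: all axes are reduced, leaving a zero-dimensional result
ndim_all = jnp.max(x).ndim

# A single int axis removes exactly one dimension
shape_one_axis = jnp.max(x, axis=1).shape

# A tuple of axes removes every listed dimension
shape_two_axes = jnp.max(x, axis=(0, 2)).shape
```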

See also

amin()

The minimum value of an array along a given axis, propagating any NaNs.

nanmax()

The maximum value of an array along a given axis, ignoring any NaNs.

maximum()

Element-wise maximum of two arrays, propagating any NaNs.

fmax()

Element-wise maximum of two arrays, ignoring any NaNs.

argmax()

Return the indices of the maximum values.

Notes

NaN values are propagated, that is, if at least one item is NaN, the corresponding max value will be NaN as well. To ignore NaN values (MATLAB behavior), please use nanmax.
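
A minimal sketch of the NaN behavior described above, written against jax.numpy rather than NumPy. The input values are illustrative.

```python
import jax.numpy as jnp

b = jnp.array([0.0, 1.0, jnp.nan, 3.0, 4.0])

# max propagates NaN: one NaN in the input makes the result NaN
propagated = jnp.max(b)

# nanmax skips NaN entries and returns the largest remaining value
ignored = jnp.nanmax(b)
```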

Don’t use amax for element-wise comparison of 2 arrays; when a.shape[0] is 2, maximum(a[0], a[1]) is faster than amax(a, axis=0).
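
The two forms mentioned in the note give the same values, as the sketch below shows with illustrative data. The speed claim comes from the original NumPy docstring; whether it carries over to JAX's compiled backend is not established here.

```python
import jax.numpy as jnp

a = jnp.array([[1, 7, 3],
               [4, 2, 9]])  # a.shape[0] == 2

# Reduction along the first axis
via_reduce = jnp.max(a, axis=0)

# Element-wise comparison of the two rows
via_elementwise = jnp.maximum(a[0], a[1])
```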

Examples

>>> a = np.arange(4).reshape((2,2))
>>> a
array([[0, 1],
       [2, 3]])
>>> np.amax(a)           # Maximum of the flattened array
3
>>> np.amax(a, axis=0)   # Maxima along the first axis
array([2, 3])
>>> np.amax(a, axis=1)   # Maxima along the second axis
array([1, 3])
>>> np.amax(a, where=[False, True], initial=-1, axis=0)
array([-1,  3])
>>> b = np.arange(5, dtype=float)
>>> b[2] = np.nan
>>> np.amax(b)
nan
>>> np.amax(b, where=~np.isnan(b), initial=-1)
4.0
>>> np.nanmax(b)
4.0


You can use an initial value to compute the maximum of an empty slice, or to initialize it to a different value:

>>> np.max([[-50], [10]], axis=-1, initial=0)
array([ 0, 10])


Notice that the initial value is used as one of the elements for which the maximum is determined, unlike the default argument of Python’s max function, which is only used for empty iterables.

>>> np.max([5], initial=6)
6
>>> max([5], default=6)
5