jax.numpy.max

jax.numpy.max(a, axis=None, out=None, keepdims=None, initial=None, where=None)

Return the maximum of an array or maximum along an axis.

LAX-backend implementation of amax().

Original docstring below.

Parameters
  • a (array_like) – Input data.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, keepdims is not passed through to the amax method of sub-classes of ndarray; any non-default value is. If the sub-class's method does not implement keepdims, an exception will be raised.

  • initial (scalar, optional) – The minimum value of an output element. Must be present to allow computation on an empty slice. See numpy.ufunc.reduce for details.

  • where (array_like of bool, optional) – Elements to compare for the maximum. See numpy.ufunc.reduce for details.
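
The parameters above can be illustrated with a minimal sketch. The array `a` and the variable names are made up for this example, and `initial` is supplied alongside `where` on the assumption that a masked maximum needs a starting value to fall back on.

```python
import jax.numpy as jnp

# Illustrative input, not taken from the docstring.
a = jnp.array([[1, 5, 2],
               [7, 3, 4]])

# axis=None (the default) reduces over the flattened input.
overall = jnp.max(a)                       # 7

# An integer axis reduces along that axis only.
col_max = jnp.max(a, axis=0)               # [7, 5, 4]
row_max = jnp.max(a, axis=1)               # [5, 7]

# keepdims=True leaves the reduced axis in place with size one,
# so the result broadcasts against the input.
kept = jnp.max(a, axis=1, keepdims=True)   # shape (2, 1)

# initial acts as a lower bound on every output element.
floored = jnp.max(a, axis=0, initial=6)    # [7, 6, 6]

# where restricts the comparison to selected elements (here, even values).
mask = (a % 2 == 0)
masked = jnp.max(a, axis=1, where=mask, initial=0)  # [2, 4]
```

Because `kept` has shape (2, 1), an expression such as `a == kept` broadcasts row by row and marks the position of each row's maximum.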

Returns

amax – Maximum of a. If axis is None, the result is a scalar value. If axis is an int, the result is an array of dimension a.ndim - 1. If axis is a tuple of ints, the result is an array of dimension a.ndim - len(axis).

Return type

ndarray or scalar
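
As a rough check of how the result's dimensionality depends on axis and keepdims, the following sketch inspects output shapes for a three-dimensional input. The array `x` and the variable names are illustrative only.

```python
import jax.numpy as jnp

# Illustrative 3-D input with shape (2, 3, 4).
x = jnp.arange(24).reshape(2, 3, 4)

full = jnp.max(x)                      # axis=None: every axis is reduced
one_axis = jnp.max(x, axis=1)          # one axis removed: shape (2, 4)
two_axes = jnp.max(x, axis=(0, 2))     # two axes removed: shape (3,)

# keepdims=True keeps the reduced axes with size one: shape (1, 3, 1)
kept = jnp.max(x, axis=(0, 2), keepdims=True)

print(full.shape, one_axis.shape, two_axes.shape, kept.shape)
```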