jax.numpy.mean

jax.numpy.mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=None)

Compute the arithmetic mean along the specified axis.

LAX-backend implementation of numpy.mean().

Original docstring below.

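Returns the average of the array elements. The average is taken over the flattened array by default, otherwise over the specified axis. For integer inputs, intermediate and return values use a floating-point dtype: float64 in NumPy, while JAX uses its default floating dtype (float32 unless 64-bit mode is enabled).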
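A minimal sketch of how the axis argument changes what is averaged, using a small hand-written 2×3 array (the values are illustrative only):

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# Default axis=None: the mean of all six elements of the flattened array.
overall = jnp.mean(x)            # 3.5

# axis=0 reduces over the rows, giving one mean per column.
col_means = jnp.mean(x, axis=0)  # [2.5, 3.5, 4.5]

# axis=1 reduces over the columns, giving one mean per row.
row_means = jnp.mean(x, axis=1)  # [2.0, 5.0]

# A tuple of axes reduces over all of them; here it matches axis=None.
both = jnp.mean(x, axis=(0, 1))  # 3.5

print(overall, col_means, row_means, both)
```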

Parameters:
  • a (array_like) – Array containing numbers whose mean is desired. If a is not an array, a conversion is attempted.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which the means are computed. The default is to compute the mean of the flattened array.

  • dtype (data-type, optional) – Type to use in computing the mean. For integer inputs, the default is float64 in NumPy (in JAX, the default floating dtype, which is float32 unless 64-bit mode is enabled); for floating-point inputs, it is the same as the input dtype.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, then keepdims will not be passed through to the mean method of subclasses of ndarray; however, any non-default value will be. If the subclass's method does not implement keepdims, any exceptions will be raised.

  • where (array_like of bool, optional) – Elements to include in the mean. See numpy.ufunc.reduce for details.

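  • out (None) – Not supported in JAX and must be left as None. JAX arrays are immutable, so the result is always returned as a new array.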
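A short sketch of keepdims and where on the same illustrative 2×3 array. With keepdims=True the per-row means keep shape (2, 1), so they can be subtracted from x directly; without it the shape would be (2,), which does not broadcast against (2, 3). The mask used for where is an arbitrary example.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# keepdims=True keeps the reduced axis with size one: shape (2, 1) instead of (2,).
row_means = jnp.mean(x, axis=1, keepdims=True)

# Because the reduced axis is kept, the result broadcasts against x.
centered = x - row_means                      # [[-1, 0, 1], [-1, 0, 1]]

# where=mask averages only the elements where the mask is True.
mask = x > 2.0
masked_mean = jnp.mean(x, where=mask)         # mean of [3, 4, 5, 6] = 4.5

# where combines with axis: each row averages only its selected elements.
row_masked = jnp.mean(x, axis=1, where=mask)  # [3.0, 5.0]
```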

Returns:

m – A new array containing the mean values. (NumPy returns a reference to the output array when out is given; JAX does not support out.)
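Because JAX arrays are immutable and there is no out buffer to fill, a common pattern is to place the returned means into an existing array with an indexed update, which produces an updated copy. A minimal sketch; the buffer and slice positions are hypothetical:

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)   # rows: [0..3], [4..7], [8..11]

# jnp.mean always returns a new array.
row_means = jnp.mean(x, axis=1)      # [1.5, 5.5, 9.5]

# To "write" the means into an existing array, build an updated copy with .at[].set().
buffer = jnp.zeros(5)                # hypothetical destination array
buffer = buffer.at[1:4].set(row_means)   # [0, 1.5, 5.5, 9.5, 0]
```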

Return type:

ndarray (a jax.Array in JAX); see the dtype parameter above.
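A small sketch of how the result dtype follows the dtype parameter. For an integer input the result is floating point (float32 by default in JAX, float64 only with 64-bit mode enabled), and an explicit dtype sets the result type; float32 is chosen here only as an example.

```python
import jax.numpy as jnp

ints = jnp.array([1, 2, 3, 4])

# Integer input: the mean comes back in a floating-point dtype, not an integer one.
m = jnp.mean(ints)                       # 2.5

# An explicit dtype sets the type of the result.
m32 = jnp.mean(ints, dtype=jnp.float32)  # 2.5, dtype float32

print(m, m.dtype, m32.dtype)
```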