jax.numpy.average(a, axis=None, weights=None, returned=False)

Compute the weighted average along the specified axis.

LAX-backend implementation of numpy.average().

Original docstring below.

Parameters

  • a (array_like) – Array containing data to be averaged. If a is not an array, a conversion is attempted.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which to average a. The default, axis=None, will average over all of the elements of the input array. If axis is negative it counts from the last to the first axis.

  • weights (array_like, optional) –

    An array of weights associated with the values in a. Each value in a contributes to the average according to its associated weight. The weights array can either be 1-D (in which case its length must be the size of a along the given axis) or of the same shape as a. If weights=None, then all data in a are assumed to have a weight equal to one. The 1-D calculation is:

    avg = sum(a * weights) / sum(weights)

    The only constraint on weights is that sum(weights) must not be 0.

  • returned (bool, optional) – Default is False. If True, the tuple (average, sum_of_weights) is returned, otherwise only the average is returned. If weights=None, sum_of_weights is equivalent to the number of elements over which the average is taken.
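
As a minimal sketch of the weights rule described above, the snippet below compares the library call against the 1-D formula written out by hand, then applies 1-D weights along one axis of a 2-D array. The array values and variable names are made up for illustration.

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0, 4.0])
w = jnp.array([4.0, 3.0, 2.0, 1.0])

# Weighted average via the library call.
avg = jnp.average(a, weights=w)

# The same quantity written out from the 1-D formula:
# sum(a * weights) / sum(weights) = 20 / 10.
manual = jnp.sum(a * w) / jnp.sum(w)

# 2-D data with 1-D weights: the length of the weights must match
# the size of `m` along the chosen axis (3 columns here).
m = jnp.arange(6.0).reshape(2, 3)      # [[0, 1, 2], [3, 4, 5]]
row_w = jnp.array([1.0, 0.0, 1.0])     # ignore the middle column
row_avg = jnp.average(m, axis=1, weights=row_w)

print(avg, manual, row_avg)
```

A weight of zero removes an element from the average, so each row above reduces to the mean of its first and last entries.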


Returns

retval, [sum_of_weights] – The average along the specified axis. When returned is True, a tuple is returned with the average as the first element and the sum of the weights as the second element. sum_of_weights has the same type as retval. The result dtype follows a general pattern. If weights is None, the result dtype is that of a, or float64 if a is integral. Otherwise, if weights is not None and a is non-integral, the result dtype is the lowest-precision type capable of representing values of both a and weights. If a is integral, the previous rule still applies, but the result dtype is at least float64.
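
The tuple form of the return value can be illustrated with a short sketch. The data and weights below are arbitrary example values. The snippet does not check dtypes, since the result precision in JAX may depend on configuration.

```python
import jax.numpy as jnp

data = jnp.array([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]])

# With weights=None, the second element is the number of elements
# the average was taken over (6 here).
mean_all, count_all = jnp.average(data, returned=True)

# With 1-D weights along axis 0, the second element holds the sum of
# the weights used for each output entry (1 + 1 + 2 = 4).
col_w = jnp.array([1.0, 1.0, 2.0])
col_avg, col_wsum = jnp.average(data, axis=0, weights=col_w, returned=True)

print(mean_all, count_all)
print(col_avg, col_wsum)
```

Requesting the sum of weights is useful when partial averages need to be combined later, because each partial result can be re-weighted by its own sum.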

Return type

array_type or double