jax.numpy.average

jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)

Compute the weighted average along the specified axis.

LAX-backend implementation of numpy.average().

Original docstring below.

Parameters
  • a (array_like) – Array containing data to be averaged. If a is not an array, a conversion is attempted.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which to average a. The default, axis=None, will average over all of the elements of the input array. If axis is negative it counts from the last to the first axis.

  • weights (array_like, optional) –

    An array of weights associated with the values in a. Each value in a contributes to the average according to its associated weight. The weights array can either be 1-D (in which case its length must be the size of a along the given axis) or of the same shape as a. If weights=None, then all data in a are assumed to have a weight equal to one. The 1-D calculation is:

    avg = sum(a * weights) / sum(weights)
    

    The only constraint on weights is that sum(weights) must not be 0.

  • returned (bool, optional) – Default is False. If True, the tuple (average, sum_of_weights) is returned, otherwise only the average is returned. If weights=None, sum_of_weights is equivalent to the number of elements over which the average is taken.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original a. Note: keepdims will not work with instances of numpy.matrix or other classes whose methods do not support keepdims.
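As a minimal sketch of how the weights and axis parameters interact, the snippet below averages a small 2-D array with a 1-D weight vector and compares the result against the formula sum(a * weights) / sum(weights). The array values and variable names are illustrative assumptions, not taken from the library's documentation.

```python
import jax.numpy as jnp

# Illustrative data: 2 rows, 3 columns (values are an assumption for this sketch).
a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# 1-D weights: the length must match the size of `a` along the chosen axis (here axis=1).
w = jnp.array([1.0, 2.0, 1.0])

# With axis=None and weights=None, every element contributes equally.
overall = jnp.average(a)

# Weighted average along axis=1: one value per row.
row_avg = jnp.average(a, axis=1, weights=w)

# The same quantity written out with the 1-D formula from the weights description.
manual = jnp.sum(a * w, axis=1) / jnp.sum(w)
```

Here row_avg and manual should agree: the first row gives (1*1 + 2*2 + 3*1) / 4 = 2.0 and the second gives (4*1 + 5*2 + 6*1) / 4 = 5.0.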

Returns

retval, [sum_of_weights] – The average along the specified axis. When returned is True, a tuple is returned with the average as the first element and the sum of the weights as the second element. sum_of_weights has the same type as retval. The result dtype follows a general pattern. If weights is None, the result dtype is that of a, or float64 if a is integral. Otherwise, if weights is not None and a is non-integral, the result type is the lowest-precision type capable of representing values of both a and weights. If a is integral, the previous rules still apply, but the result dtype will be at least float64.

Return type

array_type or double
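The following sketch illustrates the returned and keepdims options and the floating-point result for integer input. The inputs are illustrative assumptions. The dtype rules above come from the original NumPy docstring; under JAX's default 32-bit mode the result is typically float32 rather than float64, so the example only checks that the result is a floating-point type.

```python
import jax.numpy as jnp

# Illustrative input: [[0., 1., 2.], [3., 4., 5.]]
a = jnp.arange(6.0).reshape(2, 3)

# returned=True yields a tuple (average, sum_of_weights).
# With weights=None, sum_of_weights is the number of elements averaged per output entry.
avg, wsum = jnp.average(a, axis=0, returned=True)

# keepdims=True leaves the reduced axis in place with size one,
# so the result broadcasts against the original array.
row_mean = jnp.average(a, axis=1, keepdims=True)
centered = a - row_mean

# Integer input still produces a floating-point average.
int_avg = jnp.average(jnp.array([1, 2, 3, 4]))
```

In this sketch avg has one entry per column, wsum has the same shape as avg, and centered keeps the shape of a because row_mean has shape (2, 1).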