jax.numpy.average

jax.numpy.average(a: ArrayLike, axis: int | Sequence[int] | None = None, weights: ArrayLike | None = None, returned: Literal[False] = False, keepdims: bool = False) -> Array
jax.numpy.average(a: ArrayLike, axis: int | Sequence[int] | None = None, weights: ArrayLike | None = None, *, returned: Literal[True], keepdims: bool = False) -> Array
jax.numpy.average(a: ArrayLike, axis: int | Sequence[int] | None = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]

Compute the weighted average along the specified axis.

LAX-backend implementation of numpy.average().
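Since the function is built from LAX operations, it should compose with JAX transformations such as jax.jit and jax.grad like other jax.numpy routines. The sketch below uses arbitrary sample values (x and w are illustrative, not from the original text) to compile a weighted average and differentiate it with respect to the data. The gradient of sum(a * w) / sum(w) with respect to a is w / sum(w).

```python
import jax
import jax.numpy as jnp

# Illustrative sample data and weights (arbitrary values).
x = jnp.array([1.0, 2.0, 3.0, 4.0])
w = jnp.array([4.0, 3.0, 2.0, 1.0])

# Compile a weighted average; both the data and the weights are traced arrays.
weighted_avg = jax.jit(lambda a, weights: jnp.average(a, weights=weights))
result = weighted_avg(x, w)
print(result)  # (4 + 6 + 6 + 4) / (4 + 3 + 2 + 1) = 20 / 10 = 2.0

# Differentiate with respect to the data: d(avg)/d(a_i) = w_i / sum(w).
grad = jax.grad(lambda a: jnp.average(a, weights=w))(x)
print(grad)  # approximately [0.4, 0.3, 0.2, 0.1]
```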

Original docstring below.

Parameters:
  • a (array_like) – Array containing data to be averaged. If a is not an array, a conversion is attempted.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which to average a. The default, axis=None, averages over all of the elements of the input array. If axis is negative, it counts from the last to the first axis.

  • weights (array_like, optional) –

    An array of weights associated with the values in a. Each value in a contributes to the average according to its associated weight. If no axis is specified, the array of weights must have the same shape as a; otherwise, the weights must either have the same shape as a or be 1-D with a length equal to the size of a along the specified axis. If weights=None, then all data in a are assumed to have a weight equal to one. The calculation is:

    avg = sum(a * weights) / sum(weights)
    

    where the sum is over all included elements. The only constraint on the values of weights is that sum(weights) must not be 0.

  • returned (bool, optional) – Default is False. If True, the tuple (average, sum_of_weights) is returned, otherwise only the average is returned. If weights=None, sum_of_weights is equivalent to the number of elements over which the average is taken.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original a. Note: keepdims will not work with instances of numpy.matrix or other classes whose methods do not support keepdims.
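The sketch below exercises the weights, returned, and keepdims parameters on a small 2x3 array. The sample values are arbitrary and chosen only so the results are easy to check by hand. It also writes the formula sum(a * weights) / sum(weights) out explicitly for comparison.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]
w = jnp.array([1.0, 2.0, 3.0])     # 1-D weights, length equal to a.shape[1]

# Weighted average along axis 1; the 1-D weights are aligned with that axis.
avg = jnp.average(a, axis=1, weights=w)
manual = jnp.sum(a * w, axis=1) / jnp.sum(w)  # the formula, written out
print(avg)  # approximately [1.3333, 4.3333], i.e. [8/6, 26/6]

# returned=True also yields the sum of the weights, broadcast to avg's shape.
avg_r, sum_of_weights = jnp.average(a, axis=1, weights=w, returned=True)
print(sum_of_weights)  # [6. 6.]

# Without weights, the "sum of weights" is the number of averaged elements.
mean_avg, count = jnp.average(a, axis=0, returned=True)
print(mean_avg, count)  # [1.5 2.5 3.5] [2. 2. 2.]

# keepdims=True keeps the reduced axis with size one, so the result
# broadcasts against the original array.
avg_kept = jnp.average(a, axis=1, weights=w, keepdims=True)
centered = a - avg_kept
print(avg_kept.shape, centered.shape)  # (2, 1) (2, 3)
```

Using 1-D weights with an integer axis avoids materializing a full weights array of the same shape as a.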

Returns:

retval, [sum_of_weights] – The average along the specified axis. When returned is True, a tuple is returned with the average as the first element and the sum of the weights as the second element. sum_of_weights has the same type as retval.

The result dtype follows a general pattern. If weights is None, the result dtype is that of a, or float64 if a is integral. Otherwise, if weights is not None and a is non-integral, the result type is the lowest-precision type capable of representing values of both a and weights. If a is integral, the same rules apply, but the result dtype will be at least float64. (In JAX, 64-bit types are canonicalized to 32-bit unless 64-bit mode is enabled with the jax_enable_x64 flag, so by default an integral a without weights yields float32.)
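The sketch below checks the integral-input rule with arbitrary sample values. Because JAX canonicalizes 64-bit types to 32-bit unless 64-bit mode is enabled, the printed dtype is expected to be float32 under the default configuration and float64 only when jax_enable_x64 is turned on.

```python
import jax.numpy as jnp

ints = jnp.array([1, 2, 3, 4])  # integer input (arbitrary sample values)

# Unweighted: integral input is promoted to a floating result dtype.
avg = jnp.average(ints)
print(avg, avg.dtype)  # 2.5, float32 by default (float64 with jax_enable_x64)

# Integer weights are promoted together with the data, so the weighted
# result is also floating point rather than an integer.
wavg = jnp.average(ints, weights=jnp.array([1, 1, 1, 1]))
print(wavg, wavg.dtype)
```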

Return type:

Array, or tuple[Array, Array] when returned is True