jax.numpy.var

jax.numpy.var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None)

Compute the variance along the specified axis.

LAX-backend implementation of numpy.var().

Original docstring below.

Returns the variance of the array elements, a measure of the spread of a distribution. The variance is computed for the flattened array by default, otherwise over the specified axis.
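For reference, the quantity being computed can be written out explicitly. This follows the definition given in NumPy's own notes for numpy.var, where the variance is the mean of the squared deviations from the mean. Here N is the number of elements being reduced and ddof is the "Delta Degrees of Freedom" parameter from the parameter list. For complex inputs the absolute value is taken before squaring, so the result is always real and non-negative.

```latex
\bar{x} = \frac{1}{N} \sum_{i=1}^{N} x_i,
\qquad
\operatorname{var}(x) = \frac{1}{N - \mathrm{ddof}} \sum_{i=1}^{N} \lvert x_i - \bar{x} \rvert^{2}
```

With the default ddof = 0 this is the maximum-likelihood estimate of the variance for normally distributed variables; ddof = 1 gives an unbiased estimator of the variance of a hypothetical infinite population.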

Parameters:
  • a (array_like) – Array containing numbers whose variance is desired. If a is not an array, a conversion is attempted.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which the variance is computed. The default is to compute the variance of the flattened array.

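  • dtype (data-type, optional) – Type to use in computing the variance. For arrays of integer type, NumPy's default is float64, while JAX uses its default floating-point type (float32 unless 64-bit mode is enabled with jax_enable_x64); for arrays of float types it is the same as the array type.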

  • ddof (int, optional) – “Delta Degrees of Freedom”: the divisor used in the calculation is N - ddof, where N represents the number of elements. By default ddof is zero.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, keepdims is not passed through to the var method of subclasses of ndarray; any non-default value is. If the subclass's method does not implement keepdims, any exceptions will be raised.

  • where (array_like of bool, optional) – Elements to include in the variance. See numpy.ufunc.reduce for details.

  • out (None) – Not supported by JAX; must be left as None.
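A short sketch of how axis, ddof, keepdims, and where interact, using a small 2×2 array whose variances are easy to check by hand. It assumes only that jax.numpy is imported as jnp; the array values and variable names are illustrative, not part of the API.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# Default axis=None: variance of the flattened array [1, 2, 3, 4].
# Mean is 2.5, squared deviations are 2.25, 0.25, 0.25, 2.25, so 5 / 4 = 1.25.
total = jnp.var(x)

# The same value written out as the mean of squared deviations (ddof=0).
manual = jnp.mean(jnp.abs(x - jnp.mean(x)) ** 2)

# axis=0 reduces over rows, giving one variance per column.
cols = jnp.var(x, axis=0)

# A tuple of axes reduces over every listed axis; here that is all of them.
both = jnp.var(x, axis=(0, 1))

# ddof=1: the divisor becomes N - 1 = 3 instead of N = 4, so 5 / 3.
sample = jnp.var(x, ddof=1)

# keepdims=True keeps the reduced axis with size one (shape (2, 1)),
# so the result broadcasts against x, e.g. to standardize each row.
row_var = jnp.var(x, axis=1, keepdims=True)
standardized = (x - jnp.mean(x, axis=1, keepdims=True)) / jnp.sqrt(row_var)

# where: only elements where the mask is True are included (1, 2, 3 here),
# and N counts only those elements. Mean 2, squared deviations 1, 0, 1 -> 2 / 3.
mask = jnp.array([[True, True],
                  [True, False]])
masked = jnp.var(x, where=mask)
```

Each row of x has variance 0.25, so dividing the centered rows by the square root of the keepdims result yields rows of [-1, 1].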

Returns:

variance – A new array containing the variance. (In NumPy, if out is given, a reference to the output array is returned instead; JAX does not support out.)

Return type:

ndarray (see the dtype parameter above)
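A minimal sketch of how the result dtype follows the dtype parameter. Because the default float width for integer input depends on whether JAX's 64-bit mode is enabled, the example only checks that such a result is floating point rather than hard-coding float32 or float64. The arrays are illustrative.

```python
import jax.numpy as jnp

ints = jnp.arange(4)  # integer array [0, 1, 2, 3]

# Integer input: the result is floating point (float32 by default,
# float64 when jax_enable_x64 is on). Mean 1.5 -> variance 5 / 4 = 1.25.
v = jnp.var(ints)

# An explicit dtype fixes the type used for the computation and the result.
v32 = jnp.var(ints, dtype=jnp.float32)

# Float input keeps its own type by default.
half = jnp.var(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float16))

# Complex input: |x - mean|**2 is real, so the result has a real float dtype.
# Mean of [1j, -1j] is 0 and each |x|**2 is 1, so the variance is 1.
c = jnp.var(jnp.array([1j, -1j]))
```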