jax.numpy.nansum

jax.numpy.nansum(a, axis=None, out=None, keepdims=False, **kwargs)
Return the sum of array elements over a given axis, treating Not a Numbers (NaNs) as zero.

LAX-backend implementation of nansum(). Original docstring below.

In NumPy versions <= 1.9.0, NaN is returned for slices that are all-NaN or empty. In later versions, zero is returned.
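The zero-for-all-NaN behavior can be checked with a minimal sketch like the one below. It assumes a recent JAX installation and that jax.numpy.nansum follows the later NumPy convention described above; the variable names are illustrative only.

```python
import jax.numpy as jnp

# A slice containing only NaNs: every element is treated as zero.
all_nan = jnp.array([jnp.nan, jnp.nan])
all_nan_total = jnp.nansum(all_nan)

# An empty array: the sum over no elements is zero.
empty = jnp.array([], dtype=jnp.float32)
empty_total = jnp.nansum(empty)

print(all_nan_total, empty_total)
```

If a NaN result is wanted for all-NaN input, it has to be detected separately, for example with isnan().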

Parameters
  • a (array_like) – Array containing numbers whose sum is desired. If a is not an array, a conversion is attempted.

  • axis ({int, tuple of int, None}, optional) – Axis or axes along which the sum is computed. The default is to compute the sum of the flattened array.

  • out (ndarray, optional) – Alternate output array in which to place the result. The default is None. If provided, it must have the same shape as the expected output, but the type will be cast if necessary. See ufuncs-output-type for more details. The casting of NaN to integer can yield unexpected results.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original a.
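As a rough illustration of how axis and keepdims interact, the sketch below reduces a small 2-D array in three ways. The array values and variable names are made up for the example, and it assumes JAX is installed with the default floating-point settings.

```python
import jax.numpy as jnp

a = jnp.array([[1.0, jnp.nan, 3.0],
               [4.0, 5.0, jnp.nan]])

# axis=None (the default): sum over the flattened array, giving a scalar.
total = jnp.nansum(a)

# axis=0: sum down each column; the reduced axis is dropped, shape (3,).
col_sums = jnp.nansum(a, axis=0)

# axis=1 with keepdims=True: the reduced axis stays as size one, shape (2, 1),
# so the result broadcasts against the original array.
row_sums = jnp.nansum(a, axis=1, keepdims=True)
shares = a / row_sums  # NaN entries in `a` remain NaN here

print(total, col_sums, row_sums.shape)
```

Because JAX arrays are immutable, the out argument is not expected to write into an existing array; treat that as an assumption to confirm against the installed JAX version.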

Returns

nansum – A new array holding the result is returned unless out is specified, in which case it is returned. If axis is None, the result is a scalar; otherwise it has the shape of a with the reduced axes removed, or kept as size-one dimensions if keepdims is True.

Return type

ndarray.

See also

numpy.sum()

Sum across array propagating NaNs.

isnan()

Show which elements are NaN.

isfinite()

Show which elements are not NaN or +/-inf.

Notes

If both positive and negative infinity are present, the sum will be Not A Number (NaN).
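The note about infinities can be seen in a short sketch. Only NaNs are replaced by zero, so infinities still take part in the sum under ordinary IEEE 754 arithmetic. The example assumes JAX is installed; the variable names are illustrative.

```python
import jax.numpy as jnp

# A single infinity dominates the sum.
pos = jnp.nansum(jnp.array([1.0, jnp.nan, jnp.inf]))
neg = jnp.nansum(jnp.array([1.0, jnp.nan, -jnp.inf]))

# inf + (-inf) is undefined, so the result is NaN even though
# the NaN in the input itself was treated as zero.
both = jnp.nansum(jnp.array([1.0, jnp.nan, jnp.inf, -jnp.inf]))

print(pos, neg, both)
```

To guard against this case, the input can be checked with isfinite() before summing.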

Examples

>>> np.nansum(1)
1
>>> np.nansum([1])
1
>>> np.nansum([1, np.nan])
1.0
>>> a = np.array([[1, 1], [1, np.nan]])
>>> np.nansum(a)
3.0
>>> np.nansum(a, axis=0)
array([2.,  1.])
>>> np.nansum([1, np.nan, np.inf])
inf
>>> np.nansum([1, np.nan, -np.inf])
-inf
>>> from numpy.testing import suppress_warnings
>>> with suppress_warnings() as sup:
...     sup.filter(RuntimeWarning)
...     np.nansum([1, np.nan, np.inf, -np.inf]) # both +/- infinity present
nan