jax.numpy.nancumsum

jax.numpy.nancumsum(a, axis=None, dtype=None)
Return the cumulative sum of array elements over a given axis, treating Not a Numbers (NaNs) as zero. The cumulative sum does not change when NaNs are encountered, and leading NaNs are replaced by zeros.

LAX-backend implementation of nancumsum(). Original docstring below.

Zeros are returned for slices that are all-NaN or empty.
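
As a minimal sketch of the behaviour described above, the snippet below compares jnp.nancumsum with a plain cumulative sum taken after the NaNs have been zeroed by hand. The jnp.where formulation is only an illustration of the semantics and is not claimed to be how the function is implemented; the array values are made up for the example.

```python
import jax.numpy as jnp

# Illustrative input with a leading NaN and an interior NaN.
a = jnp.array([jnp.nan, 1.0, jnp.nan, 2.0])

# NaNs contribute zero: the leading NaN yields 0 and the later NaN
# leaves the running total unchanged.
result = jnp.nancumsum(a)

# The same values, obtained by zeroing the NaNs explicitly before cumsum.
manual = jnp.cumsum(jnp.where(jnp.isnan(a), 0.0, a))

# A slice that is entirely NaN accumulates to zeros.
all_nan = jnp.nancumsum(jnp.array([jnp.nan, jnp.nan]))

print(result)
print(manual)
print(all_nan)
```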

New in version 1.12.0.

Parameters
  • a (array_like) – Input array.

  • axis (int, optional) – Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.

  • dtype (dtype, optional) – Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.
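
The parameters above can be sketched as follows. The input array is an assumption chosen for illustration, and the dtype call simply requests a float32 accumulator for integer input.

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, jnp.nan]])

flat = jnp.nancumsum(a)            # axis=None: accumulate over the flattened array
down = jnp.nancumsum(a, axis=0)    # accumulate down each column
across = jnp.nancumsum(a, axis=1)  # accumulate along each row

# dtype selects the type of the accumulator and of the returned array.
as_float = jnp.nancumsum(jnp.array([1, 2, 3]), dtype=jnp.float32)

print(flat.shape, down.shape, across.shape)
print(as_float.dtype)
```

With axis=None the result is one-dimensional and has as many elements as a; with an explicit axis the result keeps the shape of a.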

Returns

nancumsum – A new array holding the result is returned unless out is specified, in which case it is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.

Return type

ndarray.

See also

numpy.cumsum()

Cumulative sum across array propagating NaNs.

isnan()

Show which elements are NaN.
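
To make the contrast with the related functions concrete, here is a small sketch using the jax.numpy counterparts jnp.cumsum and jnp.isnan. The input values are illustrative only.

```python
import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, 2.0])

propagated = jnp.cumsum(a)   # every entry from the first NaN onward is NaN
ignored = jnp.nancumsum(a)   # the NaN is treated as zero
mask = jnp.isnan(a)          # boolean mask marking the NaN entries

print(propagated)
print(ignored)
print(mask)
```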

Examples

>>> import numpy as np
>>> np.nancumsum(1)
array([1])
>>> np.nancumsum([1])
array([1])
>>> np.nancumsum([1, np.nan])
array([1.,  1.])
>>> a = np.array([[1, 2], [3, np.nan]])
>>> np.nancumsum(a)
array([1.,  3.,  6.,  6.])
>>> np.nancumsum(a, axis=0)
array([[1.,  2.],
       [4.,  2.]])
>>> np.nancumsum(a, axis=1)
array([[1.,  3.],
       [3.,  3.]])