jax.numpy.nancumsum

jax.numpy.nancumsum(a, axis=None, dtype=None, out=None)

Cumulative sum of elements along an axis, ignoring NaN values.

JAX implementation of numpy.nancumsum().
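The NaN-skipping behaviour can be pictured as an ordinary cumulative sum over an array whose NaN entries have been replaced by zero. The snippet below is a minimal sketch of that equivalence, built from `jnp.where`, `jnp.isnan` and `jnp.cumsum`. The helper name `nancumsum_sketch` is hypothetical, and the sketch is not necessarily how JAX implements the function internally.

```python
import jax.numpy as jnp

def nancumsum_sketch(a, axis=None):
    """Hypothetical helper: replace NaN with 0, then take an ordinary cumsum."""
    a = jnp.asarray(a)
    zeroed = jnp.where(jnp.isnan(a), 0, a)  # NaN entries become 0, others unchanged
    return jnp.cumsum(zeroed, axis=axis)

x = jnp.array([[1., 2., jnp.nan],
               [4., jnp.nan, 6.]])

# The sketch should agree with jnp.nancumsum, both flattened and along axis 1.
print(jnp.array_equal(nancumsum_sketch(x), jnp.nancumsum(x)))                  # True
print(jnp.array_equal(nancumsum_sketch(x, axis=1), jnp.nancumsum(x, axis=1)))  # True
```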

Parameters:
  • a (ArrayLike) – N-dimensional array to be accumulated.

  • axis (int | None) – integer axis along which to accumulate. If None (default), the array is flattened and accumulated along the flattened axis.

  • dtype (DTypeLike | None) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.

  • out (None) – unused by JAX
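A short sketch of how `axis` and `dtype` affect the result, using the same input as the examples on this page plus a small integer array chosen for illustration. With `axis=None` the output is 1-D because the input is flattened; with an explicit axis it keeps the input's shape; and `dtype` requests a particular output dtype.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., jnp.nan],
               [4., jnp.nan, 6.]])

# axis=None (default): the input is flattened first, so the result is 1-D.
flat = jnp.nancumsum(x)
print(flat.shape)   # (6,)

# axis=0: accumulate down each column; the result keeps the input's shape.
cols = jnp.nancumsum(x, axis=0)
print(cols.shape)   # (2, 3)

# dtype: request a float32 result from integer input.
ints = jnp.array([1, 2, 3])
as_float = jnp.nancumsum(ints, dtype=jnp.float32)
print(as_float.dtype)  # float32
```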

Returns:

An array containing the accumulated sum along the given axis.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., jnp.nan],
...                [4., jnp.nan, 6.]])

The standard cumulative sum will propagate NaN values:

>>> jnp.cumsum(x)
Array([ 1.,  3., nan, nan, nan, nan], dtype=float32)

nancumsum() will ignore NaN values, effectively replacing them with zeros:

>>> jnp.nancumsum(x)
Array([ 1.,  3.,  3.,  7.,  7., 13.], dtype=float32)

Cumulative sum along axis 1:

>>> jnp.nancumsum(x, axis=1)
Array([[ 1.,  3.,  3.],
       [ 4.,  4., 10.]], dtype=float32)
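Because NaN values are treated as zeros, a row made up entirely of NaN accumulates to zeros rather than NaN. The sketch below checks this under `jax.jit`; fixing `axis=1` through a closure is an illustrative choice, and the name `row_nancumsum` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical wrapper: fix axis=1 via a closure so the jitted function takes one argument.
row_nancumsum = jax.jit(lambda a: jnp.nancumsum(a, axis=1))

y = jnp.array([[jnp.nan, jnp.nan],
               [1., jnp.nan]])
out = row_nancumsum(y)
print(out)  # first row is all zeros; second row holds 1. in both positions
```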