# jax.numpy.cumsum

jax.numpy.cumsum(a, axis=None, dtype=None)

Return the cumulative sum of the elements along a given axis.

LAX-backend implementation of numpy.cumsum(). The original NumPy docstring follows.
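As a quick check of the LAX-backend remark, the sketch below compares jnp.cumsum with jax.lax.cumsum, the LAX cumulative-sum primitive, on a small 2-D array. It only verifies that the two agree numerically; it does not show how jnp.cumsum dispatches internally. The variable names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

# A small 2-D integer array to scan over.
x = jnp.array([[1, 2, 3], [4, 5, 6]])

# Running sums along axis 1 through the NumPy-style API...
via_numpy_api = jnp.cumsum(x, axis=1)

# ...and through the LAX primitive, which takes a single integer axis.
via_lax = lax.cumsum(x, axis=1)

# Both hold the values [[1, 3, 6], [4, 9, 15]].
print(via_numpy_api)
print(via_lax)
```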

Parameters
• a (array_like) β Input array.

• axis (int, optional) β Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.

• dtype (dtype, optional) β Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.

Returns

cumsum_along_axis β A new array holding the result is returned unless out is specified, in which case a reference to out is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.

Return type

ndarray.
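The shape rule above can be checked directly. In this sketch, assuming an illustrative 3-D array of ones, an explicit (here negative) axis preserves the input shape, while the default axis=None returns a 1-D array with a.size elements. Because JAX arrays are immutable, the result is always a new array.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4))

# With an explicit axis, the output keeps the input's shape.
along_last = jnp.cumsum(a, axis=-1)

# With axis=None, the output is 1-D with a.size (= 24) elements.
flattened = jnp.cumsum(a)

print(along_last.shape, flattened.shape)
```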

See also

sum()

Sum array elements.

trapz()

Integration of array values using the composite trapezoidal rule.

diff()

Calculate the n-th discrete difference along the given axis.
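The sum() and diff() entries above relate to cumsum in a simple way, illustrated here on an arbitrary 1-D input: the last element of the running sum equals the total sum, and the first-order difference of the running sum recovers every element of the input except the first.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 4, 1, 5])
running = jnp.cumsum(x)  # values: [3, 4, 8, 9, 14]

# The last running total is the sum of all elements.
total = jnp.sum(x)

# diff() undoes cumsum() for all but the first element...
recovered_tail = jnp.diff(running)  # values: [1, 4, 1, 5]

# ...so prepending the first running total recovers x exactly.
recovered = jnp.concatenate([running[:1], recovered_tail])
```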

Notes

Arithmetic is modular when using integer types, and no error is raised on overflow.
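To make the note concrete, this sketch forces an int8 accumulator through the dtype argument so that the running sum exceeds the int8 range of [-128, 127] and wraps modulo 256 without any error; a wider accumulator such as int32 avoids the overflow. The dtype is passed explicitly so the example does not depend on how small integer inputs are promoted by default.

```python
import jax.numpy as jnp

# int8 holds values in [-128, 127].
x = jnp.array([100, 100, 100], dtype=jnp.int8)

# int8 accumulator: 100, then 200 wraps to -56, then 300 wraps to 44 (mod 256).
wrapped = jnp.cumsum(x, dtype=jnp.int8)

# int32 accumulator: no overflow, so the totals are 100, 200, 300.
wide = jnp.cumsum(x, dtype=jnp.int32)
```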

Examples

>>> import numpy as np
>>> a = np.array([[1, 2, 3], [4, 5, 6]])
>>> a
array([[1, 2, 3],
       [4, 5, 6]])
>>> np.cumsum(a)
array([ 1,  3,  6, 10, 15, 21])
>>> np.cumsum(a, dtype=float)     # specifies type of output value(s)
array([ 1.,  3.,  6., 10., 15., 21.])

>>> np.cumsum(a, axis=0)      # sum over rows for each of the 3 columns
array([[1, 2, 3],
       [5, 7, 9]])
>>> np.cumsum(a, axis=1)      # sum over columns for each of the 2 rows
array([[ 1,  3,  6],
       [ 4,  9, 15]])
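The examples above come from the NumPy docstring. A rough JAX counterpart is sketched below, wrapping jnp.cumsum in jax.jit; the wrapper name running_total is hypothetical. axis is marked static because jnp.cumsum needs it as a concrete Python value at trace time, and it changes the output shape.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="axis")
def running_total(x, axis=None):
    # axis is a static (compile-time) argument; each distinct value retraces.
    return jnp.cumsum(x, axis=axis)


a = jnp.array([[1, 2, 3], [4, 5, 6]])

flat = running_total(a)           # values: [1, 3, 6, 10, 15, 21]
down_rows = running_total(a, axis=0)  # values: [[1, 2, 3], [5, 7, 9]]
across_cols = running_total(a, axis=1)  # values: [[1, 3, 6], [4, 9, 15]]
```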