jax.numpy.cumsum

jax.numpy.cumsum(a, axis=None, dtype=None, out=None)

Return the cumulative sum of the elements along a given axis.
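A minimal sketch of what "cumulative sum" means, assuming JAX is installed. It compares `jnp.cumsum` on a 1-d array with a plain-Python running total from `itertools.accumulate`. The input values are illustrative only.

```python
import itertools

import jax.numpy as jnp

values = [1, 2, 3, 4]

# Each output element is the sum of all inputs up to and including that position.
running = jnp.cumsum(jnp.array(values))

# Reference running total computed in plain Python for comparison.
reference = list(itertools.accumulate(values))

print(running.tolist())  # [1, 3, 6, 10]
print(reference)         # [1, 3, 6, 10]
```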

LAX-backend implementation of numpy.cumsum().
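Because the implementation is built on LAX, the result composes with JAX transformations. The sketch below assumes `jax.lax.cumsum` is the relevant LAX primitive; it only checks that the two give the same values, not how `jnp.cumsum` is wired internally. It also shows the function under `jax.jit` and `jax.grad`. The axis is fixed inside a lambda because `axis` must be a static value.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)

# The NumPy-style wrapper and the LAX primitive agree along an explicit axis.
via_jnp = jnp.cumsum(x, axis=1)
via_lax = lax.cumsum(x, axis=1)

# Composes with jit; axis is kept static by closing over it in the lambda.
jitted = jax.jit(lambda v: jnp.cumsum(v, axis=1))(x)

# Composes with grad: d/dv_i of sum(cumsum(v)) counts the partial sums containing v_i.
grad_of_total = jax.grad(lambda v: jnp.cumsum(v).sum())(jnp.ones(4))

print(bool(jnp.allclose(via_jnp, via_lax)))  # True
print(bool(jnp.allclose(via_jnp, jitted)))   # True
print(grad_of_total.tolist())                # [4.0, 3.0, 2.0, 1.0]
```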

Unlike the NumPy counterpart, when dtype is not specified, the output dtype will always match the dtype of the input.
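A small sketch of this difference, assuming both NumPy and JAX are installed. The `int8` input is chosen because the second partial sum, 200, does not fit in `int8`. NumPy accumulates in the default platform integer, whose exact width varies by platform. JAX keeps `int8`, so that element cannot come back as 200.

```python
import numpy as np
import jax.numpy as jnp

small = np.array([100, 100], dtype=np.int8)

# NumPy: small integer inputs are accumulated in the default platform integer.
np_result = np.cumsum(small)

# JAX: with dtype unspecified, the result keeps the input dtype (int8 here),
# so the partial sum 200 is not representable in the output.
jnp_result = jnp.cumsum(jnp.asarray(small))

print(np_result.dtype, np_result.tolist())  # platform integer (e.g. int64), [100, 200]
print(jnp_result.dtype)                     # int8
```

To keep NumPy-like precision in JAX, pass a wider `dtype` explicitly.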

Original NumPy docstring below.

Parameters:
  • a (array_like) – Input array.

  • axis (int, optional) – Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.

  • dtype (dtype, optional) – Type of the returned array and of the accumulator in which the elements are summed. In NumPy, if dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer, in which case the default platform integer is used. JAX instead keeps the dtype of a, as noted above.

  • out (None) – Not supported by JAX, whose arrays are immutable; leave it as None.
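A sketch exercising the three optional parameters, assuming default JAX settings where 64-bit types are disabled. That is why `int32` is used as the wider type rather than `int64`. As far as we know, passing a non-`None` `out` raises `NotImplementedError`. The `.at[].set()` call is the usual functional substitute for writing into an existing buffer. The names `buffer` and `out_rejected` are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# axis: None flattens first; an integer accumulates along that axis only.
flat = jnp.cumsum(x)            # [1, 3, 6, 10, 15, 21]
down = jnp.cumsum(x, axis=0)    # [[1, 2, 3], [5, 7, 9]]
across = jnp.cumsum(x, axis=1)  # [[1, 3, 6], [4, 9, 15]]

# dtype: sets the result and accumulator type, so int32 avoids int8 overflow.
small = jnp.array([100, 100], dtype=jnp.int8)
widened = jnp.cumsum(small, dtype=jnp.int32)  # [100, 200], dtype int32

# out: not supported; JAX arrays cannot be written in place.
buffer = jnp.zeros((2, 3), dtype=x.dtype)
out_rejected = False
try:
    jnp.cumsum(x[0], out=buffer[0])
except NotImplementedError:
    out_rejected = True

# Functional replacement for out=: build a new array with row 0 filled in.
buffer = buffer.at[0].set(jnp.cumsum(x[0]))  # [[1, 3, 6], [0, 0, 0]]

print(flat.tolist(), down.tolist(), across.tolist())
print(widened.dtype, widened.tolist())
print(out_rejected, buffer.tolist())
```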

Returns:

cumsum_along_axis – A new array holding the result. (In NumPy, if out is specified, a reference to out is returned instead; JAX does not support out.) The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array; otherwise it is a flattened 1-d array.
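A quick check of these shape rules, assuming JAX is installed. The arrays are arbitrary all-ones inputs chosen so that the last flattened partial sum equals the element count.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4))

flat = jnp.cumsum(a)                 # axis=None: an N-d input is flattened first
along_last = jnp.cumsum(a, axis=-1)  # explicit axis: shape is preserved
one_d = jnp.cumsum(jnp.ones(5))      # 1-d input: shape preserved even with axis=None

print(flat.shape, along_last.shape, one_d.shape)  # (24,) (2, 3, 4) (5,)
print(float(flat[-1]))                            # 24.0, the sum of all 24 ones
```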

Return type:

ndarray.