jax.numpy.cumulative_sum
- jax.numpy.cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False)
Cumulative sum along the axis of an array.
JAX implementation of numpy.cumulative_sum().
- Parameters:
x (ArrayLike) – N-dimensional array
axis (int | None) – integer axis along which to accumulate. If x is one-dimensional, this argument is optional and defaults to zero. For arrays with more than one dimension it must be given explicitly.
dtype (DTypeLike | None) – optional dtype of the output.
include_initial (bool) – if True, prepend the initial value (zero) along axis, so the output is one element longer than x along that axis. Default is False.
- Returns:
An array containing the accumulated values.
- Return type:
Array
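The parameters above can be exercised with a short sketch. It assumes a recent JAX release in which calling cumulative_sum on a multi-dimensional array without an explicit axis raises ValueError, and in which dtype sets the dtype of the returned array. The variable names are illustrative only.

```python
import jax.numpy as jnp

# 1-D input: axis may be omitted and defaults to 0.
v = jnp.array([1, 2, 3, 4])
print(jnp.cumulative_sum(v).tolist())  # [1, 3, 6, 10]

# 2-D input: leaving axis unset is assumed to raise ValueError,
# so the axis is passed explicitly instead.
m = jnp.ones((2, 3), dtype=jnp.int32)
try:
    jnp.cumulative_sum(m)
    raised = False
except ValueError:
    raised = True
print(raised)  # True
print(jnp.cumulative_sum(m, axis=0).tolist())   # [[1, 1, 1], [2, 2, 2]]
print(jnp.cumulative_sum(m, axis=-1).tolist())  # [[1, 2, 3], [1, 2, 3]]

# dtype sets the dtype of the output array.
f = jnp.cumulative_sum(v, dtype=jnp.float32)
print(f.dtype)  # float32
```

Negative axis values count from the last dimension, as elsewhere in jax.numpy.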
See also
jax.numpy.cumsum(): alternative API for cumulative sum.
jax.numpy.nancumsum(): cumulative sum while ignoring NaN values.
jax.numpy.add.accumulate(): cumulative sum via the ufunc API.
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumulative_sum(x, axis=1)
Array([[ 1,  3,  6],
       [ 4,  9, 15]], dtype=int32)
>>> jnp.cumulative_sum(x, axis=1, include_initial=True)
Array([[ 0,  1,  3,  6],
       [ 0,  4,  9, 15]], dtype=int32)
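To make include_initial concrete, here is a minimal sketch that builds the same result from jnp.cumsum and jnp.concatenate by prepending a zero slice (the additive identity) along axis. The helper cumulative_sum_with_initial is hypothetical, written only for illustration, and is not claimed to be how JAX implements the option.

```python
import jax.numpy as jnp


def cumulative_sum_with_initial(x, axis):
    """Hypothetical helper: cumsum along `axis` with a zero slice prepended."""
    summed = jnp.cumsum(x, axis=axis)
    # Shape of a single slice along `axis`, filled with zeros.
    zero_shape = list(summed.shape)
    zero_shape[axis] = 1
    zeros = jnp.zeros(zero_shape, dtype=summed.dtype)
    return jnp.concatenate([zeros, summed], axis=axis)


x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
for ax in (0, 1):
    ours = cumulative_sum_with_initial(x, axis=ax)
    ref = jnp.cumulative_sum(x, axis=ax, include_initial=True)
    print(ax, ours.shape, bool(jnp.array_equal(ours, ref)))
# 0 (3, 3) True
# 1 (2, 4) True
```

Because the prepended slice is zero, the output grows by one along axis, matching the shapes shown in the examples.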