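jax.numpy.cumulative_sum

jax.numpy.cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False)

Cumulative sum of elements along an axis.

Parameters:
x (ArrayLike): N-dimensional input array.
axis (int | None): axis along which to accumulate. This argument is optional only when x is one-dimensional, in which case the sum runs along axis 0.
dtype (DTypeLike | None): optional dtype for the output.
include_initial (bool): if True, the result starts with the initial value (zero), so it is one element longer along axis than x. Default is False.

Return type: Array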
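The sketch below shows how each keyword changes the result. It assumes a JAX version that exposes jnp.cumulative_sum. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

# For 1-D input, axis may be omitted; the running total is taken along axis 0.
running = jnp.cumulative_sum(x)  # [1, 3, 6, 10]

# include_initial=True prepends the additive identity (0),
# so the output has one more element along the summed axis.
with_initial = jnp.cumulative_sum(x, include_initial=True)  # [0, 1, 3, 6, 10]

# For 2-D input, pass axis explicitly.
m = jnp.array([[1, 2, 3],
               [4, 5, 6]])
rows = jnp.cumulative_sum(m, axis=1)  # [[1, 3, 6], [4, 9, 15]]
cols = jnp.cumulative_sum(m, axis=0)  # [[1, 2, 3], [5, 7, 9]]

# dtype sets the dtype of the returned array.
as_float = jnp.cumulative_sum(x, dtype=jnp.float32)
```

With include_initial=True, the last element still equals the total sum. The extra leading zero can help when cumulative sums are used as offsets, for example as segment boundaries.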