jax.numpy.cumulative_sum

jax.numpy.cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False)
Cumulative sum of the elements of x along axis.
Parameters:
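  • x (ArrayLike) – input array with at least one dimension; a scalar input raises a ValueError.

  • axis (int | None) – axis along which to accumulate; negative values count from the last axis. It is optional only when x is one-dimensional, in which case it defaults to 0. Inputs of higher rank require an explicit axis.

  • dtype (DTypeLike | None) – optional data type of the returned array.

  • include_initial (bool) – if True, the initial value of the sum (zero) is placed first along axis, so the output is one element longer than x along that axis. Defaults to False.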

Return type:

Array
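The following sketch exercises each parameter on small integer arrays. It assumes a JAX release that exports jnp.cumulative_sum; the variable names are illustrative only, and the trailing comments give the values each call is expected to produce.

```python
import jax.numpy as jnp

# One-dimensional input: axis may be omitted and defaults to 0.
x = jnp.array([1, 2, 3, 4])
running = jnp.cumulative_sum(x)                             # [1, 3, 6, 10]

# include_initial=True places the additive identity (0) first,
# so the result is one element longer than x along the axis.
with_initial = jnp.cumulative_sum(x, include_initial=True)  # [0, 1, 3, 6, 10]

# dtype sets the data type of the returned array.
as_float = jnp.cumulative_sum(x, dtype=jnp.float32)         # [1., 3., 6., 10.]

# Two-dimensional input: axis must be given explicitly.
m = jnp.array([[1, 2, 3],
               [4, 5, 6]])
col_sums = jnp.cumulative_sum(m, axis=0)   # [[1, 2, 3], [5, 7, 9]]
row_sums = jnp.cumulative_sum(m, axis=-1)  # [[1, 3, 6], [4, 9, 15]]

# include_initial grows only the accumulated axis: shape (2, 3) -> (2, 4).
padded = jnp.cumulative_sum(m, axis=1, include_initial=True)

# Omitting axis for a rank-2 input raises ValueError instead of
# flattening the array first.
try:
    jnp.cumulative_sum(m)
    axis_required = False
except ValueError:
    axis_required = True

print(running, with_initial, as_float.dtype, padded.shape, axis_required)
```

Unlike jnp.cumsum, which flattens its input when axis is None, cumulative_sum insists on an explicit axis for inputs of rank greater than one, so a missing axis surfaces as an error rather than a silently reshaped result.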