jax.lax.cumsum

jax.lax.cumsum(operand, axis=0, reverse=False)

Computes a cumulative sum along axis.

Parameters
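  • operand (Any) – the input array to accumulate over.

  • axis (int) – the axis along which the running sum is taken. Defaults to 0.

  • reverse (bool) – if True, accumulate from the end of the axis toward the start. Defaults to False.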

Return type

Any
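
Example

The following is a minimal sketch of how the three parameters interact, not an excerpt from the JAX sources. The array x and the result names (down_rows, across_cols, from_end) are made up for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical 2x3 input, chosen only to make the sums easy to check by hand.
x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Default axis=0: each entry is the sum of itself and the entries above it.
down_rows = lax.cumsum(x)

# axis=1: each entry is the sum of itself and the entries to its left.
across_cols = lax.cumsum(x, axis=1)

# reverse=True: accumulate from the end of the axis toward the start,
# so each entry is the sum of itself and the entries to its right.
from_end = lax.cumsum(x, axis=1, reverse=True)

print(down_rows.tolist())    # [[1, 2, 3], [5, 7, 9]]
print(across_cols.tolist())  # [[1, 3, 6], [4, 9, 15]]
print(from_end.tolist())     # [[6, 5, 3], [15, 11, 6]]
```

The result has the same shape as operand. With reverse=True the output matches flipping the input along axis, taking the forward cumulative sum, and flipping back.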