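jax.lax.cumsum

jax.lax.cumsum(operand, axis=0, reverse=False)

Computes a cumulative sum of operand along axis.

Parameters
operand (Any) – the input array whose elements are accumulated.
axis (int) – the axis along which the running sum is computed. Defaults to 0.
reverse (bool) – if True, the sum accumulates from the end of the axis toward the start instead of from the start toward the end. Defaults to False.

Return type
Any – an array with the same shape as operand, where each element is the sum of all elements up to and including that position along axis.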
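The following minimal sketch shows the effect of the axis and reverse parameters on small integer arrays. The variable names are illustrative only; the example uses non-negative axis values, since lax-level operations may not accept negative axes in every JAX version.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3, 4])

# Forward running sum: each element is the sum of itself and everything before it.
forward = lax.cumsum(x)  # [1, 3, 6, 10]

# Reverse running sum: each element is the sum of itself and everything after it.
backward = lax.cumsum(x, reverse=True)  # [10, 9, 7, 4]

m = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# axis=1 accumulates across each row.
rows = lax.cumsum(m, axis=1)  # [[1, 3, 6], [4, 9, 15]]

# axis=0 (the default) accumulates down each column.
cols = lax.cumsum(m, axis=0)  # [[1, 2, 3], [5, 7, 9]]
```

A reverse cumulative sum is equivalent to flipping the array along axis, taking a forward cumulative sum, and flipping the result back; reverse=True expresses this directly without the two extra flips.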