jax.lax.cumlogsumexp

jax.lax.cumlogsumexp(operand, axis=0, reverse=False)

Computes a cumulative logsumexp along axis.

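Parameters:
operand: the input array.
axis: the axis along which to accumulate (default 0).
reverse: if True, accumulate starting from the end of the axis (default False).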
Return type:

Array
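
As an illustration of the behaviour described above, here is a minimal sketch, assuming the function is available as jax.lax.cumlogsumexp in the installed JAX version. It compares the result against a naive log(cumsum(exp(x))) reference, which is used here only as a cross-check on small inputs and is not a claim about how the library computes the result. The array x and the variable names are hypothetical, chosen for this example.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical example input; small values so the naive reference cannot overflow.
x = jnp.array([0.0, 1.0, 2.0, 3.0])

# Forward accumulation: forward[i] = log(sum(exp(x[:i + 1]))).
forward = lax.cumlogsumexp(x, axis=0)

# Naive reference for cross-checking; exp() can overflow for large inputs.
reference = jnp.log(jnp.cumsum(jnp.exp(x)))

# reverse=True accumulates from the end: backward[i] = log(sum(exp(x[i:]))).
backward = lax.cumlogsumexp(x, axis=0, reverse=True)
reference_rev = jnp.flip(jnp.log(jnp.cumsum(jnp.exp(jnp.flip(x)))))

print(forward)
print(backward)
```

The last element of the forward result and the first element of the reversed result both equal the logsumexp of the whole array, since each covers every element along the axis.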