jax.scipy.special.logsumexp


jax.scipy.special.logsumexp(a: Array | ndarray | bool_ | number | bool | int | float | complex, axis: int | Sequence[int] | None = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) -> Array
jax.scipy.special.logsumexp(a: Array | ndarray | bool_ | number | bool | int | float | complex, axis: int | Sequence[int] | None = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) -> tuple[Array, Array]
jax.scipy.special.logsumexp(a: Array | ndarray | bool_ | number | bool | int | float | complex, axis: int | Sequence[int] | None = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]

Log-sum-exp reduction.

JAX implementation of scipy.special.logsumexp().

\[\mathrm{logsumexp}(a)_i = \log \sum_j b_{ij} \exp(a_{ij})\]

where the \(j\) indices range over one or more dimensions to be reduced, \(i\) indexes the remaining dimensions, and \(b_{ij} = 1\) when b is not given.
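Evaluating the formula literally overflows as soon as any \(a_{ij}\) is large, because \(\exp(a_{ij})\) exceeds the floating-point range. A common remedy is to subtract the maximum \(m\) first, using \(\log \sum_j \exp(a_j) = m + \log \sum_j \exp(a_j - m)\). The sketch below shows that trick with a hypothetical helper `logsumexp_sketch`; it is an illustration of the technique, not JAX's actual implementation, and it ignores `b`, `where`, `keepdims`, and the all-`-inf` edge case.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def logsumexp_sketch(a, axis=None):
    # Hypothetical helper for illustration only (not part of JAX).
    # Max-shift trick: log(sum(exp(a))) = m + log(sum(exp(a - m))), m = max(a).
    # Does not handle b, where, keepdims, or a slice that is entirely -inf
    # (then a - m is NaN).
    m = jnp.max(a, axis=axis, keepdims=True)       # shift kept broadcastable
    s = jnp.sum(jnp.exp(a - m), axis=axis)          # every term is <= 1
    return jnp.log(s) + jnp.max(a, axis=axis)       # undo the shift


a = jnp.array([1000.0, 1000.0])
naive = jnp.log(jnp.sum(jnp.exp(a)))   # exp(1000) overflows float32 to inf
stable = logsumexp_sketch(a)           # about 1000 + log(2)
library = logsumexp(a)

x = jnp.arange(6.0).reshape(2, 3)
rows_sketch = logsumexp_sketch(x, axis=1)
rows_library = logsumexp(x, axis=1)
```

With these inputs the naive expression evaluates to `inf`, while the shifted version and `jax.scipy.special.logsumexp` both stay finite and agree.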

Parameters:
  • a – the input array

  • axis – the axis or axes over which to reduce. May be either None, an int, or a tuple of ints. If None, the reduction is over all elements of a.

  • b – scaling factors for \(\mathrm{exp}(a)\). Must be broadcastable to the shape of a.

  • keepdims – If True, the axes that are reduced are left in the output as dimensions of size 1.

  • return_sign – If True, the output will be a (result, sign) pair, where sign is the sign of the sums and result contains the logarithms of their absolute values. If False, only result is returned, and it contains NaN wherever a sum is negative.

  • where – boolean array selecting the elements to include in the reduction. Must be broadcastable to the shape of a.
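The parameters above can be exercised on a small array. This is a minimal sketch; the input is chosen as the log of known values, so each expected result is simply the log of an ordinary sum.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# exp(x) recovers [[1, 2, 3], [4, 5, 6]], so results are logs of plain sums.
x = jnp.log(jnp.array([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]]))

# axis=None (the default) reduces over every element: log(21)
total = logsumexp(x)

# A tuple of axes covering both dimensions gives the same full reduction
total_tuple = logsumexp(x, axis=(0, 1))

# Reduce each row: [log(6), log(15)], shape (2,)
rows = logsumexp(x, axis=1)

# keepdims=True keeps the reduced axis with size 1, shape (2, 1)
rows_kept = logsumexp(x, axis=1, keepdims=True)

# where: only True entries participate; the last column is excluded here,
# giving [log(1 + 2), log(4 + 5)]
mask = jnp.array([[True, True, False],
                  [True, True, False]])
masked = logsumexp(x, axis=1, where=mask)
```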

Returns:

Either an array result or a pair of arrays (result, sign), depending on the value of the return_sign argument.
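Negative scaling factors in b can make a sum negative, which is where return_sign matters. In the sketch below the weighted sum is \(1 \cdot e^0 - 2 \cdot e^0 = -1\). With the default return_sign=False a single array comes back and holds NaN; with return_sign=True the pair (result, sign) encodes the sum as sign · exp(result).

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([0.0, 0.0])
b = jnp.array([1.0, -2.0])  # weighted sum: 1*exp(0) - 2*exp(0) = -1

# return_sign=False (default): one array; the log of a negative sum is NaN
plain = logsumexp(a, b=b)

# return_sign=True: result = log(|sum|) = log(1) = 0, sign = -1
result, sign = logsumexp(a, b=b, return_sign=True)

# Recover the signed sum from the pair
signed_sum = sign * jnp.exp(result)
```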