# Source code for jax._src.third_party.scipy.betaln

from jax import lax
import jax.numpy as jnp
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import promote_args_inexact

# Note: for mysterious reasons, annotating this leads to very slow mypy runs.
# def algdiv(a: ArrayLike, b: ArrayLike) -> Array:

def algdiv(a, b):
    """
    Compute ``log(gamma(b) / gamma(a + b))`` when ``b >= 8``.

    Derived from scipy's implementation of `algdiv`_.

    This differs from the scipy implementation in that it assumes a <= b
    because recomputing ``a, b = jnp.minimum(a, b), jnp.maximum(a, b)`` might
    be expensive and this is only called by ``betaln``.

    Args:
      a: the smaller argument; ``a <= b`` is assumed and not checked.
      b: the larger argument; the series below is meant for ``b >= 8``.

    Returns:
      ``lgamma(b) - lgamma(a + b)``, evaluated as ``w - u - v`` (see below)
      instead of as a direct difference of two lgamma values.

    .. _algdiv:
        https://github.com/scipy/scipy/blob/c89dfc2b90d993f2a8174e57e0cbc8fbe6f3ee19/scipy/special/cdflib/algdiv.f
    """
    c0 = 0.833333333333333e-01
    c1 = -0.277777777760991e-02
    c2 = 0.793650666825390e-03
    c3 = -0.595202931351870e-03
    c4 = 0.837308034031215e-03
    c5 = -0.165322962780713e-02
    h = a / b
    # For a <= b (cdflib's "a .le. b" branch): c = a / (a + b) and
    # x = b / (a + b), so that 1 - x == c. The sN sums below depend on this
    # relation; x must NOT equal c.
    c = h / (1 + h)
    x = 1 / (1 + h)
    d = b + (a - 0.5)
    # Set sN = (1 - x**n)/(1 - x)
    x2 = x * x
    s3 = 1.0 + (x + x2)
    s5 = 1.0 + (x + x2 * s3)
    s7 = 1.0 + (x + x2 * s5)
    s9 = 1.0 + (x + x2 * s7)
    s11 = 1.0 + (x + x2 * s9)
    # Set w = del(b) - del(a + b)
    # where del(x) is defined by ln(gamma(x)) = (x - 0.5)*ln(x) - x + 0.5*ln(2*pi) + del(x)
    t = (1.0 / b) ** 2
    w = ((((c5 * s11 * t + c4 * s9) * t + c3 * s7) * t + c2 * s5) * t + c1 * s3) * t + c0
    w = w * (c / b)
    # Combine the results:
    # lgamma(b) - lgamma(a + b) = w - d*log1p(a/b) - a*(log(b) - 1).
    u = d * lax.log1p(h)
    v = a * (lax.log(b) - 1.0)
    # Subtract the smaller of u, v first (same ordering as cdflib).
    return jnp.where(u <= v, (w - v) - u, (w - u) - v)


def betaln(a: ArrayLike, b: ArrayLike) -> Array:
    """Return the natural logarithm of the beta function ``B(a, b)``.

    Port of scipy's cdflib `betaln`_, restricted to two cases:

    * the larger argument is below 8: ``lgamma(a) + lgamma(b) - lgamma(a + b)``
      is evaluated directly;
    * otherwise: ``lgamma(lo) + algdiv(lo, hi)``, where ``algdiv`` computes
      ``lgamma(hi) - lgamma(lo + hi)`` from an asymptotic series.

    Not every branch of the scipy routine is reproduced. Even so, this is far
    more accurate than the naive lgamma combination once the inputs grow
    large (beyond roughly 1e6).

    .. _betaln:
        https://github.com/scipy/scipy/blob/ef2dee592ba8fb900ff2308b9d1c79e4d6a0ad8b/scipy/special/cdflib/betaln.f
    """
    x, y = promote_args_inexact("betaln", a, b)
    # betaln is symmetric; order the pair so that algdiv's lo <= hi
    # precondition holds.
    lo = jnp.minimum(x, y)
    hi = jnp.maximum(x, y)
    lgamma_lo = lax.lgamma(lo)
    direct = lgamma_lo + (lax.lgamma(hi) - lax.lgamma(lo + hi))
    asymptotic = lgamma_lo + algdiv(lo, hi)
    # Both branches are computed. The asymptotic series is only selected when
    # hi >= 8, which is the range algdiv is written for.
    return jnp.where(hi < 8, direct, asymptotic)