jax.scipy.special.log_ndtr

jax.scipy.special.log_ndtr = <jax._src.custom_derivatives.custom_jvp object>

Log Normal distribution function.

JAX implementation of scipy.special.log_ndtr.

For details of the Normal distribution function see ndtr.
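A quick way to see why a dedicated log-CDF is useful is to compare it with the naive composition `jnp.log(ndtr(x))` far in the left tail. The sketch below assumes default single precision (an explicit `float32` input); the value -40.0 is an arbitrary illustration point, not something the documentation prescribes.

```python
import jax.numpy as jnp
from jax.scipy.special import log_ndtr, ndtr

x = jnp.float32(-40.0)

# Naive route: ndtr(-40) is far below the smallest float32, so it underflows to 0
# and its log is -inf.
naive = jnp.log(ndtr(x))

# log_ndtr evaluates the tail in log space and stays finite (about -804.6 here).
stable = log_ndtr(x)

print(naive, stable)
```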

This function calculates \(\log(\mathrm{ndtr}(x))\) piecewise: it evaluates \(\log(\mathrm{ndtr}(x))\) directly only where that is numerically safe and switches to an approximation in both tails. Specifically:

  • For x > upper_segment, use the approximation -ndtr(-x). Since \(\mathrm{ndtr}(x) = 1 - \mathrm{ndtr}(-x)\), this follows from \(\log(1-y) \approx -y\) for \(y \ll 1\), with \(y = \mathrm{ndtr}(-x)\).

  • For lower_segment < x <= upper_segment, compute ndtr(x) and take its logarithm.

  • For x <= lower_segment, use the asymptotic series approximation of ndtr given below to compute the log CDF directly.
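The upper branch matters because in float32 `ndtr(x)` rounds to exactly 1.0 well before `x` is large, so `log(ndtr(x))` collapses to 0 and loses the tiny (negative) answer. The sketch below assumes a `float32` input; x = 6.0 is an illustrative point chosen above the float32 upper_segment of 5.

```python
import jax.numpy as jnp
from jax.scipy.special import log_ndtr, ndtr

x = jnp.float32(6.0)  # above the float32 upper_segment of 5

# ndtr(6) = 1 - 9.87e-10, which rounds to exactly 1.0 in float32,
# so the naive log returns 0.0.
naive = jnp.log(ndtr(x))

# Upper-branch identity: log(ndtr(x)) = log(1 - ndtr(-x)) ~= -ndtr(-x)
# when ndtr(-x) << 1; ndtr(-x) is small and representable, so nothing is lost.
approx = -ndtr(-x)

result = log_ndtr(x)
print(naive, approx, result)
```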

Both lower_segment and upper_segment are set based on the precision of the input:

\[\begin{split}\begin{align} \mathit{lower\_segment} =& \ \begin{cases} -20 & x.\mathrm{dtype}=\mathit{float64} \\ -10 & x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \\ \mathit{upper\_segment} =& \ \begin{cases} 8& x.\mathrm{dtype}=\mathit{float64} \\ 5& x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \end{align}\end{split}\]
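The table above can be read as a per-dtype lookup that decides which of the three branches each element uses. The sketch below is only an illustration of that selection logic: `SEGMENTS` and `branch_index` are hypothetical names, not part of the JAX API, and the boundary conventions (`<=` on both segments) follow the bullet list above.

```python
import jax.numpy as jnp

# Hypothetical lookup mirroring the documented segment boundaries.
SEGMENTS = {
    jnp.dtype(jnp.float64): (-20.0, 8.0),
    jnp.dtype(jnp.float32): (-10.0, 5.0),
}


def branch_index(x):
    """Hypothetical helper: 0 = asymptotic series, 1 = log(ndtr(x)), 2 = -ndtr(-x)."""
    x = jnp.asarray(x)
    lower, upper = SEGMENTS[x.dtype]  # unsupported dtypes raise KeyError in this sketch
    # x <= lower -> series; lower < x <= upper -> log(ndtr); x > upper -> -ndtr(-x)
    return jnp.where(x <= lower, 0, jnp.where(x <= upper, 1, 2))


xs = jnp.array([-12.0, -10.0, 0.0, 5.0, 6.0], dtype=jnp.float32)
print(branch_index(xs))
```

Note that the boundary values themselves (-10 and 5 for float32) fall into the lower and middle branches respectively.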

For x <= lower_segment, the asymptotic series approximation of ndtr is:

\[\begin{split}\begin{align} \mathrm{ndtr}(x) =&\ \mathit{scale} \cdot (1 + \mathit{sum}) + R_N \\ \mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\ \mathit{sum} =&\ \sum_{n=1}^N (-1)^n \frac{(2n-1)!!}{(x^2)^n} \\ R_N =&\ O\left(\frac{e^{-0.5 x^2} (2N+1)!!}{|x|^{2N+3}}\right) \end{align}\end{split}\]

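where \((2n-1)!! = (2n-1) (2n-3) (2n-5) \cdots (3) (1)\) denotes the double factorial. Taking logarithms gives \(\log(\mathrm{ndtr}(x)) \approx -0.5 x^2 - \log(-x) - 0.5 \log(2\pi) + \log(1 + \mathit{sum})\), which never forms the underflowing factor \(e^{-0.5 x^2}\) explicitly.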
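The series can be evaluated entirely in log space by accumulating each term from the previous one: term n equals term n-1 multiplied by \(-(2n-1)/x^2\). The sketch below is a hypothetical helper (`log_ndtr_asymptotic` is not a JAX function) that implements the formula as written above, then compares it with `jax.scipy.special.log_ndtr` at float32 points below the float32 lower_segment of -10, passing the same `series_order` to both.

```python
import jax.numpy as jnp
from jax.scipy.special import log_ndtr


def log_ndtr_asymptotic(x, series_order=3):
    """Hypothetical helper: log(ndtr(x)) for very negative x via the asymptotic series."""
    x = jnp.asarray(x)
    x_2 = jnp.square(x)
    # log(scale) = -x^2/2 - log(-x) - log(sqrt(2*pi)), kept in log space so it cannot underflow.
    log_scale = -0.5 * x_2 - jnp.log(-x) - 0.5 * jnp.log(2.0 * jnp.pi)
    # sum_{n=1}^{N} (-1)^n (2n-1)!! / (x^2)^n, built term by term.
    total = jnp.zeros_like(x)
    term = jnp.ones_like(x)
    for n in range(1, series_order + 1):
        # Multiplying by -(2n-1)/x^2 turns term n-1 into term n.
        term = term * (-(2 * n - 1)) / x_2
        total = total + term
    # log(scale * (1 + sum)) = log(scale) + log1p(sum)
    return log_scale + jnp.log1p(total)


xs = jnp.array([-30.0, -20.0, -15.0, -11.0], dtype=jnp.float32)
approx = log_ndtr_asymptotic(xs, 3)
exact = log_ndtr(xs, 3)
print(approx, exact)
```

Computing the terms incrementally avoids forming large double factorials or high powers of x separately, which keeps every intermediate value in a comfortable float32 range.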

Parameters:
  • x (ArrayLike) – an array of dtype float32 or float64.

  • series_order (int) – Non-negative Python integer. Maximum depth to evaluate the asymptotic expansion; this is the N above.

Returns:

an array with dtype=x.dtype.

Raises:
  • TypeError – if x.dtype is not handled.

  • TypeError – if series_order is not a Python integer.

  • ValueError – if series_order is not in [0, 30].

Return type:

Array
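The dtype and argument checks listed under Returns and Raises can be exercised directly. This sketch assumes that float16 counts as an unhandled dtype (it is neither float32 nor float64) and uses 31 as an out-of-range series_order; `raises` is a small hypothetical helper, not part of JAX.

```python
import jax.numpy as jnp
from jax.scipy.special import log_ndtr


def raises(exc_type, fn, *args):
    """Hypothetical helper: True if fn(*args) raises exc_type."""
    try:
        fn(*args)
    except exc_type:
        return True
    return False


x32 = jnp.linspace(-5.0, 5.0, 5, dtype=jnp.float32)

# The result keeps the input dtype.
out_dtype = log_ndtr(x32).dtype

# series_order outside [0, 30] is rejected with ValueError.
bad_order = raises(ValueError, log_ndtr, x32, 31)

# float16 is neither float32 nor float64, so it is not handled.
bad_dtype = raises(TypeError, log_ndtr, jnp.ones(3, dtype=jnp.float16))

print(out_dtype, bad_order, bad_dtype)
```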