jax.scipy.special.log_ndtr
- jax.scipy.special.log_ndtr = <jax._src.custom_derivatives.custom_jvp object>
Log Normal distribution function.
JAX implementation of scipy.special.log_ndtr. For details of the Normal distribution function see ndtr.
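This function calculates \(\log(\mathrm{ndtr}(x))\) either by evaluating \(\mathrm{ndtr}(x)\) and taking its logarithm or by using an asymptotic series. Specifically:
For x > upper_segment, use the approximation \(\log(\mathrm{ndtr}(x)) \approx -\mathrm{ndtr}(-x)\), which follows from \(\mathrm{ndtr}(x) = 1 - \mathrm{ndtr}(-x)\) and \(\log(1-y) \approx -y, y \ll 1\).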
For lower_segment < x <= upper_segment, use the existing ndtr technique and take a log.
For x <= lower_segment, we use the series approximation of erf to compute the log CDF directly.
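As a rough illustration of why the three regimes matter, the sketch below compares log_ndtr with a naive jnp.log(ndtr(x)) on float32 inputs. The sample points and the variable names naive and stable are chosen only for this example.

```python
import jax.numpy as jnp
from jax.scipy.special import log_ndtr, ndtr

# float32 inputs: far left tail, two central points, and a point in the right tail.
x = jnp.array([-20.0, -1.0, 0.0, 6.0], dtype=jnp.float32)

# Naive approach: ndtr(-20) underflows to 0 in float32, so its log is -inf.
naive = jnp.log(ndtr(x))

# log_ndtr stays finite in the left tail (about -203.9 at x = -20).
stable = log_ndtr(x)

print(naive)
print(stable)
```

In the central region the two results should agree, since log_ndtr takes the log of ndtr there; the difference appears only in the tails.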
The lower_segment and upper_segment thresholds are set based on the precision of the input:
\[\begin{split}\begin{align} \mathit{lower\_segment} =& \ \begin{cases} -20 & x.\mathrm{dtype}=\mathit{float64} \\ -10 & x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \\ \mathit{upper\_segment} =& \ \begin{cases} 8& x.\mathrm{dtype}=\mathit{float64} \\ 5& x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \end{align}\end{split}\]
When x <= lower_segment, the ndtr asymptotic series approximation is:
\[\begin{split}\begin{align} \mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\ \mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\ \mathit{sum} =&\ \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\ R_N =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3}) \end{align}\end{split}\]
where \((2n-1)!! = (2n-1) (2n-3) (2n-5) \cdots (3) (1)\) is the double factorial.
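The series can be evaluated in log space by writing \(\log(\mathrm{ndtr}(x)) \approx \log(\mathit{scale}) + \log(1 + \mathit{sum})\). The following is a minimal sketch of that idea, not the library's implementation: asymptotic_log_ndtr is a hypothetical helper written for this example, and it assumes float32 inputs that are well below lower_segment.

```python
import math

import jax.numpy as jnp
from jax.scipy.special import log_ndtr


def asymptotic_log_ndtr(x, series_order=3):
    """Hypothetical helper: log(scale * (1 + sum)) for strongly negative x."""
    x = jnp.asarray(x, dtype=jnp.float32)
    x2 = x * x
    total = jnp.zeros_like(x)   # accumulates sum_{n=1}^{N} (-1)^n (2n-1)!! / (x^2)^n
    double_fact = 1.0           # (2n-1)!!, built up incrementally
    power = jnp.ones_like(x)    # (x^2)^n
    for n in range(1, series_order + 1):
        double_fact *= 2 * n - 1
        power = power * x2
        total = total + (-1.0) ** n * double_fact / power
    # log(scale) = -x^2 / 2 - log(-x) - log(sqrt(2 * pi))
    log_scale = -0.5 * x2 - jnp.log(-x) - 0.5 * math.log(2.0 * math.pi)
    return log_scale + jnp.log1p(total)


x = jnp.array([-12.0, -20.0, -30.0], dtype=jnp.float32)
approx = asymptotic_log_ndtr(x)
reference = log_ndtr(x)
print(approx)
print(reference)
```

Working with log(scale) rather than scale avoids the underflow of \(e^{-0.5 x^2}\) for large |x|, which is the reason the series is used to compute the log CDF directly.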
- Parameters:
x (ArrayLike) – an array of type float32 or float64.
series_order (int) – Positive Python integer. Maximum depth to evaluate the asymptotic expansion. This is the N above.
- Returns:
an array with dtype=x.dtype.
- Raises:
TypeError – if x.dtype is not handled.
TypeError – if series_order is not a Python integer.
ValueError – if series_order is not in [0, 30].
- Return type: