jax.scipy.stats.truncnorm.logpdf

jax.scipy.stats.truncnorm.logpdf(x, a, b, loc=0, scale=1)

Truncated normal log probability density function.

JAX implementation of scipy.stats.truncnorm logpdf.
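Because the function mirrors SciPy, one way to sanity-check results is to evaluate both side by side. This is a minimal sketch: the sample points and bounds are arbitrary, and the tolerance is loosened on the assumption that JAX runs in its default float32 mode while SciPy computes in float64.

```python
import numpy as np
import scipy.stats
import jax.numpy as jnp
from jax.scipy.stats import truncnorm

# Arbitrary evaluation points inside the truncation interval [a, b].
x = np.array([-0.5, 0.0, 0.75, 1.5])
a, b = -1.0, 2.0

jax_vals = truncnorm.logpdf(jnp.asarray(x), a, b)   # float32 by default
scipy_vals = scipy.stats.truncnorm.logpdf(x, a, b)  # float64

# Loose tolerance to absorb the float32 vs float64 difference.
match = np.allclose(np.asarray(jax_vals), scipy_vals, rtol=1e-5, atol=1e-6)
print(match)
```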

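The truncated normal probability density function is given by

\[f(x, a, b) = \begin{cases} \dfrac{1}{\Phi(b) - \Phi(a)} \cdot \dfrac{1}{\sqrt{2\pi}} e^{-x^2/2} & a \le x \le b \\ 0 & \mathrm{otherwise} \end{cases}\]

where \(\Phi\) is the standard normal cumulative distribution function, so the factor \(1/(\Phi(b) - \Phi(a))\) renormalizes the density to integrate to one over \([a, b]\). The bounds \(a\) and \(b\) are specified in units of standard deviations of the standardized (zero-mean, unit-variance) normal. JAX uses the SciPy nomenclature of loc for the centroid (the mean of the untruncated normal) and scale for its standard deviation: the density at \(x\) is \(f\left(\frac{x - \mathrm{loc}}{\mathrm{scale}}, a, b\right) / \mathrm{scale}\), so the support in terms of \(x\) is \([\mathrm{loc} + a \cdot \mathrm{scale}, \mathrm{loc} + b \cdot \mathrm{scale}]\).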
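To make the loc/scale convention concrete, the sketch below evaluates the normalized formula directly and compares it with the library call. `truncnorm_logpdf_manual` is a hypothetical helper written for illustration, built from `jax.scipy.stats.norm.logpdf` and `norm.cdf`; it takes the log of \(\Phi(b) - \Phi(a)\) naively, which can lose precision when both bounds lie far in one tail, so treat it as an illustration rather than a replacement.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.stats import norm, truncnorm

def truncnorm_logpdf_manual(x, a, b, loc=0.0, scale=1.0):
    """Hypothetical helper: direct evaluation of the truncated normal log-density."""
    z = (x - loc) / scale                             # standardize x
    log_mass = jnp.log(norm.cdf(b) - norm.cdf(a))     # log(Phi(b) - Phi(a)), naive form
    val = norm.logpdf(z) - log_mass - jnp.log(scale)  # log(phi(z) / mass / scale)
    # Zero density outside [a, b] in standardized units, i.e. log-density -inf.
    return jnp.where((z < a) | (z > b), -jnp.inf, val)

loc, scale = 2.0, 1.5
a, b = -1.0, 1.0                          # support in x is [0.5, 3.5]
x = jnp.array([0.0, 1.0, 2.0, 3.5, 5.0])  # first and last points fall outside

manual = truncnorm_logpdf_manual(x, a, b, loc, scale)
library = truncnorm.logpdf(x, a, b, loc=loc, scale=scale)
print(np.allclose(np.asarray(manual), np.asarray(library), rtol=1e-5))
```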

Parameters:
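  • x – arraylike, values at which to evaluate the log-PDF

  • a – arraylike, lower truncation bound in standard deviations (the support starts at loc + a * scale)

  • b – arraylike, upper truncation bound in standard deviations (the support ends at loc + b * scale)

  • loc – arraylike, distribution offset parameter (mean of the untruncated normal)

  • scale – arraylike, distribution scale parameter (standard deviation of the untruncated normal)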
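Since every parameter is arraylike, a single call can evaluate several truncation intervals at once. The sketch below assumes a scalar `x` combines with vector `a` and `b` in the usual way, and uses `jax.vmap` to sweep a grid of `x` values so the output shape stays explicit; the particular intervals and grid are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import truncnorm

# Three truncation intervals (standardized units), evaluated at one point.
a = jnp.array([-1.0, 0.0, 1.0])
b = jnp.array([1.0, 2.0, 3.0])
vals = truncnorm.logpdf(0.5, a, b)  # one value per interval; 0.5 lies outside [1, 3]

# Sweep a grid of x values against all three intervals.
xs = jnp.linspace(-2.0, 2.0, 5)
grid = jax.vmap(lambda xi: truncnorm.logpdf(xi, a, b))(xs)  # rows: x, columns: interval
print(vals.shape, grid.shape)
```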

Returns:

array of logpdf values.
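A few properties of the returned array follow from the formula above and can be checked numerically; this is a sketch with arbitrary bounds and data. Exponentiating the log-densities over \([a, b]\) should integrate to roughly one (a hand-rolled trapezoid rule is used here), points outside the interval should give `-inf`, and because the result is an ordinary JAX array it composes with `jax.grad`. With \(a\) and \(b\) held fixed in standardized units, the derivative of the log-density with respect to loc is \((x - \mathrm{loc}) / \mathrm{scale}^2\), so at loc=0 and scale=1 the gradient of the summed log-likelihood should equal the sum of the data, 0.75.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import truncnorm

a, b = -1.0, 2.0

# Trapezoid-rule integral of the density over [a, b].
xs = jnp.linspace(a, b, 2001)
dens = jnp.exp(truncnorm.logpdf(xs, a, b))
dx = (b - a) / (xs.size - 1)
total = dx * (dens.sum() - 0.5 * (dens[0] + dens[-1]))

# Points outside [a, b] have zero density, i.e. log-density -inf.
outside = truncnorm.logpdf(jnp.array([-3.0, 4.0]), a, b)

# Gradient of a summed log-likelihood with respect to loc.
data = jnp.array([0.25, -0.5, 1.0])

def log_lik(loc):
    return truncnorm.logpdf(data, a, b, loc=loc).sum()

g = jax.grad(log_lik)(0.0)
print(total, outside, g)
```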