jax.random.truncated_normal


jax.random.truncated_normal(key, lower, upper, shape=None, dtype=<class 'float'>)

Sample truncated standard normal random values with given shape and dtype.

The values are returned according to the probability density function:

\[f(x) \propto e^{-x^2/2}\]

on the domain \(\rm{lower} < x < \rm{upper}\).
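The density above is only stated up to proportionality. The normalizing constant is the probability mass the standard normal places on the interval: dividing the standard normal density \(\phi\) by \(\Phi(\rm{upper}) - \Phi(\rm{lower})\), where \(\Phi\) is the standard normal CDF, makes it integrate to 1. Here is a minimal sketch of that normalized density. The helper name `truncated_normal_pdf` is hypothetical and is not part of `jax.random`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def truncated_normal_pdf(x, lower, upper):
    """Hypothetical helper: normalized truncated standard normal density.

    Returns phi(x) / (Phi(upper) - Phi(lower)) for lower < x < upper, else 0.
    """
    x = jnp.asarray(x)
    mass = norm.cdf(upper) - norm.cdf(lower)  # probability mass kept by truncation
    inside = (x > lower) & (x < upper)        # open interval, matching the domain above
    return jnp.where(inside, norm.pdf(x) / mass, 0.0)


# Riemann-sum check that the density integrates to roughly 1 on (-1, 2).
xs = jnp.linspace(-1.0, 2.0, 3001)
dx = 3.0 / 3000
total = jnp.sum(truncated_normal_pdf(xs, -1.0, 2.0)) * dx
```

The `total` should come out close to 1. Inside the interval the truncated density is the ordinary normal density scaled up by 1 / mass, so it is always larger there.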

Parameters:
  • key (KeyArrayLike) – a PRNG key used as the random key.

  • lower (RealArray) – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.

  • upper (RealArray) – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.

  • shape (Shape | NamedShape | None) – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.

  • dtype (DTypeLikeFloat) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
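The following sketch shows how `lower`, `upper`, and `shape` broadcast together. It draws a `(4, 3)` array in which each column has its own lower bound and all columns share one scalar upper bound. It assumes a JAX version that provides `jax.random.key`; older code may use `jax.random.PRNGKey(0)` instead. The seed and bounds are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# One lower bound per column; a single scalar upper bound shared by all columns.
lower = jnp.array([-2.0, 0.0, 1.0])  # shape (3,)
upper = 3.0                          # scalar, broadcasts against lower

# The broadcast shape of lower and upper is (3,), which broadcasts to shape=(4, 3).
samples = jax.random.truncated_normal(key, lower, upper, shape=(4, 3))

print(samples.shape)  # (4, 3)
print(samples.dtype)  # float32 unless jax_enable_x64 is enabled
```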

Return type:

Array

Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).
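When `shape` is omitted, the result takes the broadcast shape of `lower` and `upper`. The sketch below checks that behavior and also confirms, for one draw, that every sample lies strictly inside its own interval. The seed and bounds are arbitrary choices for illustration, and it again assumes `jax.random.key` is available.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)

lower = jnp.array([[-1.0], [-0.5]])  # shape (2, 1): one lower bound per row
upper = jnp.array([0.0, 1.0, 2.0])   # shape (3,): one upper bound per column

# shape=None: the result shape is the broadcast of (2, 1) and (3,), i.e. (2, 3).
x = jax.random.truncated_normal(key, lower, upper)

# Each entry x[i, j] lies in the open interval (lower[i, 0], upper[j]).
assert x.shape == (2, 3)
assert bool(jnp.all((x > lower) & (x < upper)))
```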