jax.nn.initializers.truncated_normal

jax.nn.initializers.truncated_normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>, lower=-2.0, upper=2.0)

Builds an initializer that returns truncated-normal random arrays.

Parameters:
  • stddev (Any) – optional; the standard deviation of the untruncated distribution. This function does not apply the stddev correction used by the variance-scaling initializers; to use that correction, fold it into the stddev argument yourself.

  • dtype (Any) – optional; the initializer’s default dtype.

  • lower (Any) – Float representing the lower bound for truncation. Applied before the output is multiplied by the stddev.

  • upper (Any) – Float representing the upper bound for truncation. Applied before the output is multiplied by the stddev.
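
The stddev correction mentioned above can be sketched as follows. Truncation shrinks the spread of the samples, so dividing the desired standard deviation by the standard deviation of a unit normal truncated to [lower, upper] compensates for it. The helper truncated_std and the value target_std are hypothetical names introduced for illustration; they are not part of the JAX API, and this is a minimal sketch rather than the library's own implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def truncated_std(lower=-2.0, upper=2.0):
    """Standard deviation of a unit normal truncated to [lower, upper].

    Hypothetical helper (not part of JAX), using the closed-form
    moments of the truncated normal distribution.
    """
    z = norm.cdf(upper) - norm.cdf(lower)            # probability mass kept
    mean = (norm.pdf(lower) - norm.pdf(upper)) / z   # mean after truncation
    var = 1.0 + (lower * norm.pdf(lower) - upper * norm.pdf(upper)) / z - mean**2
    return jnp.sqrt(var)


target_std = 0.02                                # assumed desired spread of the weights
correction = float(truncated_std(-2.0, 2.0))     # roughly 0.88 for the default bounds

# Pass the corrected value through the stddev argument.
init = jax.nn.initializers.truncated_normal(stddev=target_std / correction)
weights = init(jax.random.PRNGKey(0), (512, 512), jnp.float32)

empirical_std = float(weights.std())             # should land close to target_std
```

Without the division by correction, the sampled weights would have a standard deviation of about 0.88 times the requested stddev under the default bounds.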

Return type:

Initializer

Returns:

An initializer that returns arrays whose values are drawn from a normal distribution with mean 0 and standard deviation stddev, truncated to the range \(\text{lower} \cdot \text{stddev} < x < \text{upper} \cdot \text{stddev}\).

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.truncated_normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 2.9047365,  5.2338114,  5.29852  ],
       [-3.836303 , -4.192359 ,  0.6022964]], dtype=float32)
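
The range given under Returns can be checked empirically. The sketch below draws a large sample and compares its extremes against lower * stddev and upper * stddev; the shape, the seed, and the variable names are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

stddev, lower, upper = 5.0, -2.0, 2.0
init = jax.nn.initializers.truncated_normal(stddev, lower=lower, upper=upper)

# Draw many values so the extremes get close to the truncation bounds.
samples = init(jax.random.PRNGKey(0), (1000, 100), jnp.float32)

smallest = float(samples.min())
largest = float(samples.max())

# Bounds are applied before scaling, so the limits are lower*stddev and upper*stddev.
within_bounds = (lower * stddev <= smallest) and (largest <= upper * stddev)
```

With these settings every sample should fall between -10 and 10, which is consistent with the values shown in the example output above.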