jax.nn.initializers.variance_scaling

jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Initializer that adapts its scale to the shape of the weights tensor.

With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of \(\sqrt{\frac{scale}{n}}\), where n is:

  • the number of input units in the weights tensor, if mode="fan_in",

  • the number of output units, if mode="fan_out", or

  • the average of the numbers of input and output units, if mode="fan_avg".
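As a rough check of the three modes, the sketch below draws a dense kernel and compares the empirical standard deviation with \(\sqrt{\frac{scale}{n}}\). The shape, seed, and scale=2.0 are arbitrary values chosen for illustration. It assumes the default in_axis=-2 and out_axis=-1, so that 512 is the input count and 256 is the output count.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

key = jax.random.PRNGKey(0)   # arbitrary seed
shape = (512, 256)            # (input units, output units) with default axes
scale = 2.0                   # illustrative value

# n for each mode, following the definitions above.
n_by_mode = {
    "fan_in": 512,
    "fan_out": 256,
    "fan_avg": (512 + 256) / 2,
}

results = {}
for mode, n in n_by_mode.items():
    init = variance_scaling(scale, mode, "normal")
    w = init(key, shape, jnp.float32)
    results[mode] = (float(jnp.std(w)), math.sqrt(scale / n))
    print(mode, results[mode])  # (empirical std, sqrt(scale / n))
```

The two numbers printed for each mode should agree to within sampling noise.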

This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes).
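To illustrate the axis arguments, here is a minimal sketch that initializes the same 3x3 convolution kernel in two layouts. The layouts, channel counts, and seed are illustrative choices. The sketch assumes that the receptive-field size multiplies the input count, giving fan_in = 64 * 3 * 3 in both cases; the empirical standard deviation is printed so this can be checked.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

key = jax.random.PRNGKey(0)  # arbitrary seed

# HWIO layout: 3x3 kernel, 64 input channels, 128 output channels.
# The defaults in_axis=-2, out_axis=-1 already match this layout.
init_hwio = variance_scaling(1.0, "fan_in", "normal")
w_hwio = init_hwio(key, (3, 3, 64, 128), jnp.float32)

# OIHW layout: the same kernel with output channels first.
init_oihw = variance_scaling(1.0, "fan_in", "normal", in_axis=1, out_axis=0)
w_oihw = init_oihw(key, (128, 64, 3, 3), jnp.float32)

# Assumed fan_in: input channels times the 3x3 receptive field.
fan_in = 64 * 3 * 3
expected_std = math.sqrt(1.0 / fan_in)
print(float(jnp.std(w_hwio)), float(jnp.std(w_oihw)), expected_std)
```

If the axes were left at their defaults for the OIHW array, the two trailing spatial axes of size 3 would be taken as the input and output dimensions, which is unlikely to be what is intended.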

With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.
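The effect of truncation can be sketched numerically. The example below assumes that "2 standard deviations" refers to the untruncated normal, and that the result is rescaled so the post-truncation standard deviation equals \(\sqrt{\frac{scale}{n}}\). Under that assumption the samples are bounded by roughly 2.27 target standard deviations rather than 2. The shape and seed are arbitrary.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

key = jax.random.PRNGKey(0)  # arbitrary seed
scale, fan_in = 1.0, 512
init = variance_scaling(scale, "fan_in", "truncated_normal")
w = init(key, (fan_in, 256), jnp.float32)

target_std = math.sqrt(scale / fan_in)

# Std of a standard normal truncated to [-a, a]:
# sqrt(1 - 2*a*pdf(a) / (2*cdf(a) - 1)), about 0.8796 for a = 2.
a = 2.0
pdf_a = math.exp(-a * a / 2) / math.sqrt(2 * math.pi)
mass = math.erf(a / math.sqrt(2))  # equals 2*cdf(a) - 1
trunc_std = math.sqrt(1 - 2 * a * pdf_a / mass)

# Assumed bound: 2 pre-truncation standard deviations, rescaled so that the
# post-truncation standard deviation equals target_std.
bound = a * target_std / trunc_std

empirical_std = float(jnp.std(w))
max_abs = float(jnp.max(jnp.abs(w)))
print(empirical_std, target_std)  # should agree up to sampling noise
print(max_abs, bound)             # max_abs should not exceed bound
```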

With distribution="uniform", samples are drawn from:

  • a uniform interval, if dtype is real, or

  • a uniform disk, if dtype is complex,

with a mean of zero and a standard deviation of \(\sqrt{\frac{scale}{n}}\) where n is defined above.
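The stated standard deviation fixes the size of the sampling region. A uniform interval [-a, a] has variance a²/3, so a = √3 · std; a uniform disk of radius R has mean squared modulus R²/2, so R = √2 · std. The sketch below checks both cases empirically. The shape and seed are arbitrary, and the complex "standard deviation" is taken here to mean the root mean squared modulus, which is an assumption about the intended definition.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

key = jax.random.PRNGKey(0)  # arbitrary seed
scale, shape = 1.0, (512, 256)
target_std = math.sqrt(scale / 512)  # fan_in = 512 for this shape

init = variance_scaling(scale, "fan_in", "uniform")

# Real dtype: an interval [-a, a] has variance a**2 / 3, so a = sqrt(3) * std.
w_real = init(key, shape, jnp.float32)
half_width = math.sqrt(3) * target_std
real_std = float(jnp.std(w_real))
real_max = float(jnp.max(jnp.abs(w_real)))

# Complex dtype: a disk of radius R has E[|z|^2] = R**2 / 2, so R = sqrt(2) * std.
w_cplx = init(key, shape, jnp.complex64)
radius = math.sqrt(2) * target_std
cplx_std = float(jnp.sqrt(jnp.mean(jnp.abs(w_cplx) ** 2)))
cplx_max = float(jnp.max(jnp.abs(w_cplx)))

print(real_std, real_max, half_width)
print(cplx_std, cplx_max, radius)
```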

Parameters:
  • scale (Any) – scaling factor (positive float).

  • mode (Literal['fan_in'] | Literal['fan_out'] | Literal['fan_avg']) – one of "fan_in", "fan_out", and "fan_avg".

  • distribution (Literal['truncated_normal'] | Literal['normal'] | Literal['uniform']) – random distribution to use. One of "truncated_normal", "normal", and "uniform".

  • in_axis (int | Sequence[int]) – axis or sequence of axes of the input dimension in the weights array.

  • out_axis (int | Sequence[int]) – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis (Sequence[int]) – axis or sequence of axes in the weights array that should be ignored when computing the numbers of input and output units.

  • dtype (Any) – the dtype of the weights.

Return type:

Initializer
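
The returned initializer is called with a PRNG key, a shape, and optionally a dtype. The following sketch shows such a call and, at the same time, the effect of batch_axis. The shape (a hypothetical stack of 8 kernels mapping 64 inputs to 32 outputs) and the seed are illustrative. It assumes that an axis listed in batch_axis is left out of n, whereas an unassigned axis is counted as receptive field.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

key = jax.random.PRNGKey(0)  # arbitrary seed
shape = (8, 64, 32)          # hypothetical stack of 8 kernels, each 64 -> 32

# Axis 0 declared as a batch axis: assumed to leave fan_in = 64.
init_batched = variance_scaling(1.0, "fan_in", "normal", batch_axis=(0,))
w_batched = init_batched(key, shape, jnp.float32)

# Axis 0 left unassigned: treated as receptive field, so fan_in = 8 * 64.
init_plain = variance_scaling(1.0, "fan_in", "normal")
w_plain = init_plain(key, shape, jnp.float32)

std_batched = float(jnp.std(w_batched))
std_plain = float(jnp.std(w_plain))
print(w_batched.shape, w_batched.dtype)
print(std_batched, math.sqrt(1 / 64))
print(std_plain, math.sqrt(1 / (8 * 64)))
```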