jax.nn.initializers package

Common neural network layer initializers, consistent with definitions used in Keras and Sonnet.

Initializers

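Each initializer is a function that takes a PRNG key, a shape, and an optional dtype, and returns an array of that shape. The zeros and ones functions are initializers themselves; the remaining entries are factories that return an initializer when called.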

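zeros(key, shape[, dtype])

An initializer that returns a constant array full of zeros.

ones(key, shape[, dtype])

An initializer that returns a constant array full of ones.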
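As a minimal sketch of the calling convention for the two constant initializers, the example below calls them directly. The key is accepted for signature compatibility with the random initializers but has no effect on the result; the shapes used here are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)

# zeros and ones are initializers themselves: call them directly with a
# PRNG key, a shape, and (optionally) a dtype.
b = initializers.zeros(key, (3,))
s = initializers.ones(key, (2, 2), jnp.float32)

print(b.shape, b.dtype)  # (3,) float32
print(s.shape, s.dtype)  # (2, 2) float32
```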

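uniform([scale, dtype])

Builds an initializer that samples real values uniformly from [0, scale) (default scale=1e-2).

normal([stddev, dtype])

Builds an initializer that samples real values from a normal distribution with mean 0 and standard deviation stddev (default stddev=1e-2).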
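Unlike zeros and ones, uniform and normal are factories: calling them returns an initializer. The sketch below builds one of each with arbitrary example parameters and checks the sample range and spread empirically; the sample sizes are only chosen so the statistics are stable.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

# Factories: calling them returns an initializer.
init_u = initializers.uniform(scale=0.5)   # samples in [0, 0.5)
init_n = initializers.normal(stddev=0.1)   # mean 0, standard deviation 0.1

# Use independent keys for independent draws.
k1, k2 = jax.random.split(jax.random.PRNGKey(42))

w_u = init_u(k1, (1000,))
w_n = init_n(k2, (10000,))

print(float(w_u.min()), float(w_u.max()))  # both inside [0, 0.5)
print(float(jnp.std(w_n)))                 # approximately 0.1
```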

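variance_scaling(scale, mode, distribution[, in_axis, out_axis, dtype])

Initializer capable of adapting its scale to the shape of the weights tensor. The samples have variance scale / n, where n is the number of input units (mode="fan_in"), output units (mode="fan_out"), or their average (mode="fan_avg"); distribution is one of "truncated_normal", "normal", or "uniform".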
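The sketch below checks the variance rule empirically for a dense kernel, assuming the default axis convention in which axis -2 counts input units and axis -1 counts output units. It uses the untruncated "normal" distribution so the empirical variance can be compared to scale / fan_in directly, and the "uniform" distribution to show the bound sqrt(3 * scale / n). The shape and scales are arbitrary examples.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn import initializers

# Dense kernel laid out as (fan_in, fan_out).
fan_in, fan_out = 256, 512
shape = (fan_in, fan_out)
key = jax.random.PRNGKey(0)

# mode="fan_in", untruncated normal: variance should be close to scale / fan_in.
w = initializers.variance_scaling(2.0, "fan_in", "normal")(key, shape)
print(float(jnp.var(w)), 2.0 / fan_in)

# mode="fan_avg", uniform: samples lie within [-limit, limit],
# where limit = sqrt(3 * scale / n) and n = (fan_in + fan_out) / 2.
w_u = initializers.variance_scaling(1.0, "fan_avg", "uniform")(key, shape)
limit = math.sqrt(3.0 * 1.0 / ((fan_in + fan_out) / 2))
print(float(jnp.abs(w_u).max()), limit)
```

The uniform bound follows from the variance of a uniform distribution on [-a, a], which is a^2 / 3; solving a^2 / 3 = scale / n gives the limit above.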

glorot_uniform([in_axis, out_axis, dtype])

Glorot (Xavier) uniform initializer, equivalent to variance_scaling(1.0, "fan_avg", "uniform").

glorot_normal([in_axis, out_axis, dtype])

Glorot (Xavier) normal initializer, equivalent to variance_scaling(1.0, "fan_avg", "truncated_normal").

lecun_uniform([in_axis, out_axis, dtype])

LeCun uniform initializer, equivalent to variance_scaling(1.0, "fan_in", "uniform").

lecun_normal([in_axis, out_axis, dtype])

LeCun normal initializer, equivalent to variance_scaling(1.0, "fan_in", "truncated_normal").

he_uniform([in_axis, out_axis, dtype])

He (Kaiming) uniform initializer, equivalent to variance_scaling(2.0, "fan_in", "uniform").

he_normal([in_axis, out_axis, dtype])

He (Kaiming) normal initializer, equivalent to variance_scaling(2.0, "fan_in", "truncated_normal").
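The named initializers above are presets of variance_scaling. As a rough check, assuming each preset forwards the same default axes and dtype, the sketch below draws from each preset and from the corresponding variance_scaling call with the same key and compares the arrays. The shape is an arbitrary example.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

# (name, factory, scale, mode, distribution)
presets = [
    ("glorot_uniform", initializers.glorot_uniform, 1.0, "fan_avg", "uniform"),
    ("glorot_normal", initializers.glorot_normal, 1.0, "fan_avg", "truncated_normal"),
    ("lecun_uniform", initializers.lecun_uniform, 1.0, "fan_in", "uniform"),
    ("lecun_normal", initializers.lecun_normal, 1.0, "fan_in", "truncated_normal"),
    ("he_uniform", initializers.he_uniform, 2.0, "fan_in", "uniform"),
    ("he_normal", initializers.he_normal, 2.0, "fan_in", "truncated_normal"),
]

key = jax.random.PRNGKey(0)
shape = (64, 128)  # (fan_in, fan_out) for a dense kernel

matches = {}
for name, factory, scale, mode, dist in presets:
    named = factory()(key, shape)
    generic = initializers.variance_scaling(scale, mode, dist)(key, shape)
    # Same key and same parameters, so the samples should be identical.
    matches[name] = bool(jnp.array_equal(named, generic))
    print(name, matches[name])
```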