jax.nn.initializers package

Common neural network layer initializers, consistent with definitions used in Keras and Sonnet.
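Every initializer in this module is a function with the signature `init(key, shape, dtype) -> Array`. `zeros` and `ones` already have that form. The remaining entries are constructors that return such a function once configured. The sketch below assumes a recent JAX release and initializes the weight and bias of a hypothetical 3-input, 4-output dense layer. The layer shape and the choice of `glorot_normal` for the kernel are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(0)
w_key, b_key = jax.random.split(key)

# `zeros` and `ones` are initializers themselves: call them directly
# with (key, shape, dtype). The key is accepted but not used.
b = initializers.zeros(b_key, (4,), jnp.float32)
g = initializers.ones(b_key, (4,), jnp.float32)  # e.g. a normalization scale

# The other entries are constructors: configure them first, then call the
# returned function with the same (key, shape, dtype) arguments.
w_init = initializers.glorot_normal()
w = w_init(w_key, (3, 4), jnp.float32)  # hypothetical 3-in, 4-out dense kernel
```

A `dtype` passed at call time overrides the default that was fixed when the constructor was configured.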

Initializers


zeros(key, shape[, dtype])

ones(key, shape[, dtype])

uniform([scale, dtype])

normal([stddev, dtype])

variance_scaling(scale, mode, distribution[, in_axis, out_axis, dtype])

glorot_uniform([in_axis, out_axis, dtype])

glorot_normal([in_axis, out_axis, dtype])

lecun_uniform([in_axis, out_axis, dtype])

lecun_normal([in_axis, out_axis, dtype])

he_uniform([in_axis, out_axis, dtype])

he_normal([in_axis, out_axis, dtype])
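The `scale` and `stddev` arguments of `uniform` and `normal` are easy to confuse with the variance-based parameters of the other constructors. As a rough check, assuming a recent JAX release: `uniform(scale)` multiplies samples from U[0, 1) by `scale`, so its output is not centered at zero, and `normal(stddev)` multiplies standard-normal samples by `stddev`. The sample size and parameter values below are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.PRNGKey(42)
u_key, n_key = jax.random.split(key)

# uniform(scale) scales samples from U[0, 1), so values fall in [0, scale).
u = initializers.uniform(scale=0.05)(u_key, (100_000,), jnp.float32)

# normal(stddev) scales standard-normal samples, giving N(0, stddev**2).
n = initializers.normal(stddev=0.02)(n_key, (100_000,), jnp.float32)

print("uniform range:", float(u.min()), float(u.max()))
print("normal mean/std:", float(n.mean()), float(n.std()))
```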
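The `glorot_*`, `lecun_*` and `he_*` constructors are specializations of `variance_scaling`. The sketch below records the (scale, mode, distribution) triple we believe each preset uses in recent JAX releases, in a hypothetical `PRESETS` table. It verifies each mapping by drawing from both the preset and the equivalent `variance_scaling` call with the same key. It also checks the Keras-style Glorot uniform bound sqrt(6 / (fan_in + fan_out)). That bound follows from `variance_scaling` drawing uniform samples from [-sqrt(3 * variance), sqrt(3 * variance)), with variance = scale / ((fan_in + fan_out) / 2) for mode "fan_avg".

```python
import math

import jax
import jax.numpy as jnp
from jax.nn import initializers

# Hypothetical table: preset -> (constructor, scale, mode, distribution),
# as assumed for recent JAX releases. The check below tests the assumption.
PRESETS = {
    "glorot_uniform": (initializers.glorot_uniform, 1.0, "fan_avg", "uniform"),
    "glorot_normal": (initializers.glorot_normal, 1.0, "fan_avg", "truncated_normal"),
    "lecun_uniform": (initializers.lecun_uniform, 1.0, "fan_in", "uniform"),
    "lecun_normal": (initializers.lecun_normal, 1.0, "fan_in", "truncated_normal"),
    "he_uniform": (initializers.he_uniform, 2.0, "fan_in", "uniform"),
    "he_normal": (initializers.he_normal, 2.0, "fan_in", "truncated_normal"),
}

key = jax.random.PRNGKey(0)
shape = (256, 128)  # (fan_in, fan_out) under the default in_axis=-2, out_axis=-1


def matches_variance_scaling(name):
    """Return True if the preset and the equivalent variance_scaling call agree."""
    preset, scale, mode, distribution = PRESETS[name]
    a = preset()(key, shape)
    b = initializers.variance_scaling(scale, mode, distribution)(key, shape)
    return bool(jnp.allclose(a, b))


# Keras-style Glorot uniform limit: sqrt(6 / (fan_in + fan_out)).
fan_in, fan_out = shape
limit = math.sqrt(6.0 / (fan_in + fan_out))
w = initializers.glorot_uniform()(key, shape)

for name in PRESETS:
    print(name, matches_variance_scaling(name))
print("max |w|:", float(jnp.max(jnp.abs(w))), "limit:", limit)
```

In recent releases, any remaining non-batch axes of a kernel (for example the spatial axes of a convolution kernel) are treated as the receptive field, and its size multiplies both fan_in and fan_out.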