jax.nn.initializers module

Common neural network layer initializers, consistent with definitions used in Keras and Sonnet.

Initializers

An initializer is a function that takes three arguments, (key, shape, dtype), and returns an array with dimensions shape and data type dtype. The key argument is a PRNG key (e.g., from jax.random.key()) used to generate the random numbers that initialize the array.
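Here is a minimal sketch of this calling convention. The layer sizes are arbitrary and chosen only for illustration.

A builder such as glorot_normal() is called first to obtain the initializer, and that initializer is then called with (key, shape, dtype). Initializers whose signature is (key, shape[, dtype]), such as zeros, are called directly. Splitting the key gives each parameter its own independent random numbers.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

# Builders such as glorot_normal() return an initializer function,
# which is then called with (key, shape, dtype).
init_fn = initializers.glorot_normal()

key = jax.random.key(0)
# Split the key so each parameter is drawn from an independent stream.
w_key, b_key = jax.random.split(key)

w = init_fn(w_key, (784, 256), jnp.float32)          # weight matrix
b = initializers.zeros(b_key, (256,), jnp.float32)   # zeros is itself an initializer

print(w.shape, w.dtype)  # (784, 256) float32
print(b.shape, b.dtype)  # (256,) float32
```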

constant(value[, dtype])

Builds an initializer that returns arrays full of a constant value.

delta_orthogonal([scale, column_axis, dtype])

Builds an initializer for delta orthogonal kernels.

glorot_normal([in_axis, out_axis, ...])

Builds a Glorot normal initializer (aka Xavier normal initializer).

glorot_uniform([in_axis, out_axis, ...])

Builds a Glorot uniform initializer (aka Xavier uniform initializer).

he_normal([in_axis, out_axis, batch_axis, dtype])

Builds a He normal initializer (aka Kaiming normal initializer).

he_uniform([in_axis, out_axis, batch_axis, ...])

Builds a He uniform initializer (aka Kaiming uniform initializer).

lecun_normal([in_axis, out_axis, ...])

Builds a LeCun normal initializer.

lecun_uniform([in_axis, out_axis, ...])

Builds a LeCun uniform initializer.

normal([stddev, dtype])

Builds an initializer that returns real, normally distributed random arrays with mean zero and standard deviation stddev.

ones(key, shape[, dtype])

An initializer that returns a constant array full of ones.

orthogonal([scale, column_axis, dtype])

Builds an initializer that returns uniformly distributed orthogonal matrices.

truncated_normal([stddev, dtype, lower, upper])

Builds an initializer that returns truncated-normal random arrays; the bounds lower and upper are applied before the samples are multiplied by stddev.

uniform([scale, dtype])

Builds an initializer that returns real, uniformly distributed random arrays in the interval [0, scale).

variance_scaling(scale, mode, distribution)

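Builds an initializer that adapts its scale to the shape of the weights tensor. The target variance is scale divided by a fan size chosen by mode (e.g., "fan_in", "fan_out", or "fan_avg"), and samples are drawn from distribution.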

zeros(key, shape[, dtype])

An initializer that returns a constant array full of zeros.
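The entries above fall into two groups:

- **Builders**, such as constant(value), return an initializer.
- **Initializers**, such as ones and zeros, take (key, shape[, dtype]) directly.

The sketch below shows both groups, using arbitrary shapes.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.key(0)

# constant(value) is a builder: it returns an initializer that fills with value.
half_init = initializers.constant(0.5)
halves = half_init(key, (2, 3), jnp.float32)

# ones and zeros are initializers themselves, so they are called directly.
# They do not use the key, but it is still a required positional argument.
ones_arr = initializers.ones(key, (2, 3), jnp.float32)
zeros_arr = initializers.zeros(key, (2, 3), jnp.float32)

print(halves)
```

Accepting an unused key keeps the deterministic initializers interchangeable with the random ones under the same (key, shape, dtype) convention.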
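The plain random initializers normal, uniform, and truncated_normal differ mainly in their support. The sketch below assumes float32 parameters and an arbitrary (100, 100) shape:

- normal(stddev) draws from a zero-mean normal distribution.
- uniform(scale) draws from [0, scale).
- truncated_normal(stddev) uses the default bounds lower=-2.0 and upper=2.0, expressed in units of stddev.

For truncated_normal, stddev describes the untruncated distribution, so the empirical standard deviation of its samples is somewhat smaller.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.key(0)
k1, k2, k3 = jax.random.split(key, 3)
shape = (100, 100)

# normal(stddev): samples from N(0, stddev**2).
w_normal = initializers.normal(stddev=0.01)(k1, shape, jnp.float32)

# uniform(scale): samples from the half-open interval [0, scale).
w_uniform = initializers.uniform(scale=0.01)(k2, shape, jnp.float32)

# truncated_normal: lower/upper are applied before scaling by stddev,
# so with the defaults every value lies within 2 * stddev of zero.
w_trunc = initializers.truncated_normal(stddev=0.01)(k3, shape, jnp.float32)

print(float(w_normal.std()), float(w_uniform.min()), float(jnp.abs(w_trunc).max()))
```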
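The Glorot, He, and LeCun initializers can be read as fixed settings of variance_scaling. The sketch below checks this numerically, assuming the current JAX definitions:

| Family | scale | mode |
| --- | --- | --- |
| glorot_* | 1.0 | "fan_avg" |
| he_* | 2.0 | "fan_in" |
| lecun_* | 1.0 | "fan_in" |

In every family, the *_normal variant uses the "truncated_normal" distribution and the *_uniform variant uses "uniform". Given the same key, shape, and dtype, each named initializer should return exactly the same array as its variance_scaling counterpart.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.key(0)
shape = (256, 128)  # default in_axis=-2, out_axis=-1 -> fan_in=256, fan_out=128
vs = initializers.variance_scaling

# Each named initializer paired with its assumed variance_scaling settings.
pairs = {
    "glorot_normal": (initializers.glorot_normal(), vs(1.0, "fan_avg", "truncated_normal")),
    "glorot_uniform": (initializers.glorot_uniform(), vs(1.0, "fan_avg", "uniform")),
    "he_normal": (initializers.he_normal(), vs(2.0, "fan_in", "truncated_normal")),
    "he_uniform": (initializers.he_uniform(), vs(2.0, "fan_in", "uniform")),
    "lecun_normal": (initializers.lecun_normal(), vs(1.0, "fan_in", "truncated_normal")),
    "lecun_uniform": (initializers.lecun_uniform(), vs(1.0, "fan_in", "uniform")),
}

# Same key, shape, and dtype should give bit-identical arrays.
matches = {
    name: bool(jnp.array_equal(named(key, shape, jnp.float32),
                               generic(key, shape, jnp.float32)))
    for name, (named, generic) in pairs.items()
}
print(matches)
```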
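As a worked example of how variance_scaling adapts to the shape, consider he_uniform on a (256, 128) matrix:

- With the defaults in_axis=-2 and out_axis=-1, fan_in = 256.
- The target variance is therefore 2/256.
- A symmetric uniform distribution U(-limit, limit) has variance limit²/3, so matching that variance gives limit = sqrt(3 · variance) ≈ 0.1531.

The sketch checks both the bound and the empirical variance.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.key(0)
fan_in, fan_out = 256, 128

# he_uniform targets variance = 2 / fan_in using U(-limit, limit),
# where limit = sqrt(3 * variance) because Var(U(-a, a)) = a**2 / 3.
variance = 2.0 / fan_in
limit = math.sqrt(3.0 * variance)

w = initializers.he_uniform()(key, (fan_in, fan_out), jnp.float32)

print(limit)                      # ≈ 0.1531
print(float(jnp.abs(w).max()))    # should not exceed limit
print(float(w.var()), variance)   # empirical vs. target variance
```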
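Here is a sketch of the property that orthogonal provides, with shapes chosen for illustration:

- For a square shape, the result W satisfies WᵀW = I.
- For a tall shape such as (6, 3), with the default column_axis=-1, the columns are orthonormal.
- The scale argument multiplies the whole matrix, so WᵀW = scale² I.

The tolerances allow for float32 rounding.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.key(0)

# Square case: W is orthogonal, so W^T W = I.
w_sq = initializers.orthogonal()(key, (4, 4), jnp.float32)

# Tall case (6 rows, 3 columns): the 3 columns are orthonormal.
w_tall = initializers.orthogonal()(key, (6, 3), jnp.float32)

# scale multiplies the whole matrix, so W^T W = scale**2 * I.
w_scaled = initializers.orthogonal(scale=2.0)(key, (4, 4), jnp.float32)

print(jnp.round(w_tall.T @ w_tall, 4))
```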
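delta_orthogonal targets convolution kernels. This sketch assumes a 3×3 kernel with 4 input and 8 output channels, laid out with spatial dimensions first and channels last (HWIO).

The initializer places an orthogonal matrix at the spatial center tap and zeros everywhere else. At initialization, the convolution therefore behaves like an orthogonal 1×1 map. The input channel count is kept no larger than the output channel count, which the JAX implementation checks for.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key = jax.random.key(0)
# 3x3 kernel, 4 input channels, 8 output channels (HWIO layout assumed).
kernel = initializers.delta_orthogonal()(key, (3, 3, 4, 8), jnp.float32)

# Only the spatial center tap is non-zero; it holds a 4x8 matrix
# whose rows are orthonormal.
center = kernel[1, 1]
off_center = kernel.at[1, 1].set(0.0)

print(center.shape)  # (4, 8)
```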