jax.nn.initializers.lecun_uniform

jax.nn.initializers.lecun_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a LeCun uniform initializer.

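A LeCun uniform initializer is a specialization of jax.nn.initializers.variance_scaling() with scale=1.0, mode="fan_in", and distribution="uniform". Weights are therefore drawn uniformly between -limit and limit, where limit = sqrt(3 / fan_in), which gives them a variance of 1 / fan_in.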
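Because lecun_uniform is described as a specialization of variance_scaling(), the two constructors should produce identical weights from the same key. The following is a minimal check under that assumption; the shape (64, 128) and the key are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
shape = (64, 128)  # fan_in = 64 with the default in_axis=-2

# The convenience constructor and the equivalent explicit variance_scaling call
lecun = jax.nn.initializers.lecun_uniform()
explicit = jax.nn.initializers.variance_scaling(1.0, "fan_in", "uniform")

w_lecun = lecun(key, shape, jnp.float32)
w_explicit = explicit(key, shape, jnp.float32)

# Same key and same parameters: the weights should be bit-identical
same = bool(jnp.array_equal(w_lecun, w_explicit))
print(same)
```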

Parameters:
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weights array that are ignored when computing fan_in and fan_out.

  • dtype – the default dtype of the weights; a dtype passed to the returned initializer overrides it.
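The axis parameters only matter through fan_in: in_axis and out_axis pick the input and output dimensions, batch_axis is excluded, and every remaining axis counts as receptive field, so fan_in = (size of in_axis) × (receptive field size). The sketch below checks this for three layouts, assuming the variance_scaling uniform rule that samples lie between -limit and limit with limit = sqrt(3 / fan_in). The shapes and the helpers lecun_limit and max_abs are illustrative choices, not part of JAX.

```python
import math
import jax
import jax.numpy as jnp

def lecun_limit(fan_in: float) -> float:
    # Half-width of a uniform distribution whose variance is 1 / fan_in
    return math.sqrt(3.0 / fan_in)

def max_abs(w) -> float:
    return float(jnp.max(jnp.abs(w)))

key = jax.random.key(0)
tol = 1 + 1e-6  # slack for float32 rounding of the limit

# Conv kernel in HWIO layout: in_axis=-2 (I=16), out_axis=-1 (O=32); the 3x3
# spatial axes form the receptive field, so fan_in = 3 * 3 * 16 = 144.
conv_w = jax.nn.initializers.lecun_uniform()(key, (3, 3, 16, 32), jnp.float32)
print(max_abs(conv_w) <= lecun_limit(144) * tol)

# Dense weights stored as (out, in) = (32, 16): swap the axes so fan_in = 16, not 32.
oi_init = jax.nn.initializers.lecun_uniform(in_axis=-1, out_axis=-2)
oi_w = oi_init(key, (32, 16), jnp.float32)
print(max_abs(oi_w) <= lecun_limit(16) * tol)

# Eight stacked (4, 5) matrices: batch_axis=0 excludes the stack axis, so
# fan_in = 4 (without it, axis 0 would count as receptive field and fan_in = 8).
stacked_init = jax.nn.initializers.lecun_uniform(batch_axis=0)
stacked_w = stacked_init(key, (8, 4, 5), jnp.float32)
print(max_abs(stacked_w) <= lecun_limit(4) * tol)
```

With the default axes, the (32, 16) array would instead get fan_in = 32 and a narrower range, which is why weights stored as (out, in) need the axes swapped.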

Returns:

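An initializer: a function that takes a PRNG key, a shape, and optionally a dtype, and returns a weights array of that shape.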
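As a usage sketch, the returned initializer can be called once per layer with a fresh key. The helpers init_mlp and mlp below are hypothetical, written only for illustration; zero biases and tanh activations are arbitrary choices. The last lines show that a dtype passed at call time takes precedence over the dtype given to lecun_uniform.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes, dtype=jnp.float32):
    """Hypothetical helper: LeCun-uniform weights and zero biases for an MLP."""
    init = jax.nn.initializers.lecun_uniform()
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # (n_in, n_out) layout matches the defaults in_axis=-2, out_axis=-1
        w = init(k, (n_in, n_out), dtype)
        b = jnp.zeros((n_out,), dtype)
        params.append((w, b))
    return params

def mlp(params, x):
    """Hypothetical forward pass: tanh hidden layers, linear output layer."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

params = init_mlp(jax.random.key(0), [8, 32, 32, 2])
y = mlp(params, jnp.ones((4, 8)))
print(y.shape)  # (4, 2)

# Factory-level dtype is only a default for the returned function
init_bf16 = jax.nn.initializers.lecun_uniform(dtype=jnp.bfloat16)
w_bf16 = init_bf16(jax.random.key(1), (4, 3))               # uses bfloat16
w_f32 = init_bf16(jax.random.key(1), (4, 3), jnp.float32)   # overridden
print(w_bf16.dtype, w_f32.dtype)
```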

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.56293887,  0.90433645,  0.9119454 ],
       [-0.71479625, -0.7676109 ,  0.12302713]], dtype=float32)
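The printed values above are specific to the key and may differ across JAX versions or PRNG implementations. A check that does not depend on exact values is to test the distribution itself: every entry should lie between -limit and limit with limit = sqrt(3 / fan_in), and the empirical variance should be close to 1 / fan_in. The sketch below assumes a (512, 256) matrix, large enough that a 5% tolerance on the variance is comfortably loose.

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 512, 256
w = jax.nn.initializers.lecun_uniform()(jax.random.key(42), (fan_in, fan_out), jnp.float32)

limit = (3.0 / fan_in) ** 0.5
target_var = 1.0 / fan_in
empirical_var = float(jnp.var(w))

# All samples fall inside the uniform range (small slack for float32 rounding)
in_range = float(jnp.max(jnp.abs(w))) <= limit * (1 + 1e-6)
# Variance matches 1 / fan_in to within a few percent
var_ok = abs(empirical_var / target_var - 1.0) < 0.05
print(in_range, var_ok)
```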