jax.nn.initializers.he_uniform

jax.nn.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a He uniform initializer (aka Kaiming uniform initializer).

A He uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="uniform".
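The equivalence above can be checked directly. The sketch below builds both initializers and compares their outputs for the same key; the shape (256, 512) is an arbitrary choice. It also checks the uniform bound, assuming the behavior of current JAX releases: with distribution="uniform", variance_scaling samples from U(-limit, limit) with limit = sqrt(3 * scale / fan_in). For scale = 2.0 that is sqrt(6 / fan_in), which gives a variance of limit**2 / 3 = 2 / fan_in.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
shape = (256, 512)  # with the default in_axis=-2, fan_in is shape[0] = 256

he = jax.nn.initializers.he_uniform()
vs = jax.nn.initializers.variance_scaling(2.0, "fan_in", "uniform")

w_he = he(key, shape, jnp.float32)
w_vs = vs(key, shape, jnp.float32)

# Same key and same configuration, so the two initializers should agree exactly.
print(jnp.array_equal(w_he, w_vs))  # True

# Samples lie in [-limit, limit] with limit = sqrt(6 / fan_in).
fan_in = shape[0]
limit = jnp.sqrt(6.0 / fan_in)
print(jnp.max(jnp.abs(w_he)) <= limit)  # True

# The empirical variance should be close to 2 / fan_in = 0.0078125.
print(jnp.var(w_he), 2.0 / fan_in)
```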

Parameters:
  • in_axis (int | Sequence[int]) – axis or sequence of axes of the input dimension in the weights array.

  • out_axis (int | Sequence[int]) – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis (Sequence[int]) – axis or sequence of axes in the weights array that are ignored when computing the fan-in and fan-out.

  • dtype (Any) – the default dtype of the weights; a dtype passed when calling the initializer takes precedence.
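To make the axis parameters concrete, the following sketch computes the fan-in used for a few kernel layouts. The helper fan_in_of is hypothetical (it is not part of JAX); it mirrors the rule used by variance_scaling in current JAX releases: every axis not listed in in_axis, out_axis, or batch_axis is treated as receptive field and multiplies the input size, while batch axes are left out entirely. The assertion checks that samples stay within sqrt(6 / fan_in).

```python
import math

import jax
import jax.numpy as jnp


def fan_in_of(shape, in_axis=-2, out_axis=-1, batch_axis=()):
    """Hypothetical helper mirroring how variance_scaling derives fan-in."""
    def prod(axes):
        axes = (axes,) if isinstance(axes, int) else tuple(axes)
        return math.prod(shape[a] for a in axes)  # empty product is 1

    # Axes that are neither input, output, nor batch form the receptive field.
    receptive_field = math.prod(shape) // (
        prod(in_axis) * prod(out_axis) * prod(batch_axis))
    return prod(in_axis) * receptive_field


key = jax.random.key(0)
cases = [
    # (shape, axis keyword arguments, description)
    ((3, 3, 16, 32), {}, "HWIO conv kernel, default axes"),
    ((32, 16, 3, 3), {"in_axis": 1, "out_axis": 0}, "OIHW conv kernel"),
    ((4, 8, 16), {"batch_axis": (0,)}, "4 stacked 8x16 dense kernels"),
]

results = []
for shape, axes, desc in cases:
    w = jax.nn.initializers.he_uniform(**axes)(key, shape, jnp.float32)
    fan_in = fan_in_of(shape, **axes)
    limit = math.sqrt(6.0 / fan_in)
    # Small slack for float32 rounding of the bound inside the initializer.
    assert float(jnp.max(jnp.abs(w))) <= limit * (1 + 1e-6)
    results.append((desc, fan_in))
    print(f"{desc}: fan_in={fan_in}, limit={limit:.4f}")
```

Both convolution layouts give fan_in = 16 * 3 * 3 = 144, while for the stacked dense kernels only the 8 input units count, because the leading axis is declared a batch axis.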

Returns:

An initializer.

Return type:

Initializer
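The returned Initializer is a callable that takes a PRNG key, a shape, and an optional dtype. The sketch below (shapes chosen arbitrarily) illustrates the dtype handling: a dtype fixed when building the initializer becomes the default for every call, and a dtype passed at call time takes precedence over it.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# dtype given at call time.
init = jax.nn.initializers.he_uniform()
w_bf16 = init(key, (64, 128), jnp.bfloat16)

# dtype fixed at construction time becomes the default for each call...
init_f16 = jax.nn.initializers.he_uniform(dtype=jnp.float16)
w_f16 = init_f16(key, (64, 128))

# ...but can still be overridden per call.
w_f32 = init_f16(key, (64, 128), jnp.float32)

print(w_bf16.dtype, w_f16.dtype, w_f32.dtype)  # bfloat16 float16 float32
```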

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)