jax.nn.initializers.uniform

jax.nn.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns real-valued, uniformly distributed random arrays.

Parameters:
  • scale (Any) – optional; the exclusive upper bound of the distribution. The lower bound is always 0.

  • dtype (Any) – optional; the initializer’s default dtype, used when no dtype is passed to the returned initializer.

Return type:

Initializer

Returns:

An initializer that returns arrays whose values are uniformly distributed in the range [0, scale).

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[7.298188 , 8.691938 , 8.7230015],
       [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
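The following is a minimal sketch of the parameters in use, assuming a JAX release recent enough to provide `jax.random.key` (as the doctest above does). It checks that samples stay within [0, scale) and that a dtype passed at call time overrides the initializer's default. It also compares the result against `scale * jax.random.uniform(...)` with the same key. That equivalence reflects the current implementation and is an assumption here, not a documented guarantee. The names `scale`, `w`, `w_bf16`, `in_range` and `matches_scaled_uniform` are illustrative only.

```python
import jax
import jax.numpy as jnp

scale = 0.5
init = jax.nn.initializers.uniform(scale=scale)
key = jax.random.key(0)

# No dtype argument: falls back to the default given when the initializer was built.
w = init(key, (4, 3))

# Per-call override: the third argument replaces the initializer's default dtype.
w_bf16 = init(key, (4, 3), jnp.bfloat16)

# Every sample should lie in the half-open interval [0, scale).
in_range = bool(jnp.all((w >= 0) & (w < scale)))

# Assumption (implementation detail, not a documented guarantee): the output equals
# a standard uniform draw with the same key, shape and dtype, multiplied by scale.
ref = scale * jax.random.uniform(key, (4, 3), jnp.float32)
matches_scaled_uniform = bool(jnp.allclose(init(key, (4, 3), jnp.float32), ref))

print(w.shape, w_bf16.dtype, in_range, matches_scaled_uniform)
```

Because the lower bound is fixed at 0, `scale` only stretches the unit interval. It does not shift it.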