jax.nn.initializers.orthogonal

jax.nn.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns uniformly distributed orthogonal matrices.

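The shape must have at least two dimensions. If it is not square, the matrices have orthonormal rows or columns, depending on which side is smaller: a wide matrix (fewer rows than columns) gets orthonormal rows, and a tall one gets orthonormal columns. For shapes with more than two dimensions, every axis except column_axis is flattened into the row dimension.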
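The row/column rule can be checked directly on the initializer's output. The snippet below is a small sketch that samples a wide, a tall, and a 3-D shape and tests orthonormality with jnp.allclose; the key value and the tolerances are arbitrary choices for illustration (float32 on CPU).

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers.orthogonal()
key = jax.random.key(0)

# Wide (2 rows < 3 columns): the rows are orthonormal, so W @ W.T == I.
wide = init(key, (2, 3), jnp.float32)
print(jnp.allclose(wide @ wide.T, jnp.eye(2), atol=1e-5))

# Tall (3 rows > 2 columns): the columns are orthonormal, so W.T @ W == I.
tall = init(key, (3, 2), jnp.float32)
print(jnp.allclose(tall.T @ tall, jnp.eye(2), atol=1e-5))

# 3-D shape: the axes other than the last (column) axis are flattened into
# 2 * 3 = 6 rows by 4 columns, so the 4 columns are orthonormal.
kernel = init(key, (2, 3, 4), jnp.float32)
flat = kernel.reshape(6, 4)
print(jnp.allclose(flat.T @ flat, jnp.eye(4), atol=1e-5))
```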

Parameters:
  • scale (Any) – a factor by which the sampled orthogonal matrix is multiplied.

  • column_axis (int) – the axis that contains the columns that should be orthogonal.

  • dtype (Any) – the default dtype of the weights.
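A short sketch of how scale and column_axis affect the result, assuming the current JAX behavior in which scale simply multiplies the orthogonal sample and column_axis selects the axis that indexes the orthonormal vectors. Shapes, key, and tolerances are illustrative choices.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# scale multiplies the orthogonal sample, so for a tall (4, 3) matrix the
# Gram matrix of its columns is scale**2 times the identity.
scaled = jax.nn.initializers.orthogonal(scale=2.0)(key, (4, 3), jnp.float32)
print(jnp.allclose(scaled.T @ scaled, 4.0 * jnp.eye(3), atol=1e-4))

# column_axis=0: axis 0 indexes the 3 "columns" (each of length 5), so in a
# (3, 5) weight the slices w[i, :] are orthonormal and w @ w.T == I.
w = jax.nn.initializers.orthogonal(column_axis=0)(key, (3, 5), jnp.float32)
print(jnp.allclose(w @ w.T, jnp.eye(3), atol=1e-5))
```

With the default column_axis=-1, the same orthonormality would instead show up along the last axis, which is the usual layout for a dense kernel of shape (in_features, out_features).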

Return type:

Initializer

Returns:

An orthogonal initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]],            dtype=float32)
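For intuition about where the orthogonal matrices come from, here is a minimal 2-D sketch of the standard construction: take the QR decomposition of a Gaussian matrix and flip column signs so that diag(R) is positive, which makes the result uniformly (Haar) distributed. The helper orthogonal_2d is hypothetical and written for illustration only; it is an assumption that it mirrors the library's internals, and it omits the dtype canonicalization and the handling of higher-rank shapes.

```python
import jax
import jax.numpy as jnp

def orthogonal_2d(key, n_rows, n_cols, scale=1.0, dtype=jnp.float32):
    """Hypothetical helper: sample an (n_rows, n_cols) matrix with orthonormal
    rows (if n_rows < n_cols) or orthonormal columns (otherwise)."""
    # Draw a tall Gaussian matrix so that reduced QR yields orthonormal columns.
    tall_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
    a = jax.random.normal(key, tall_shape, dtype)
    q, r = jnp.linalg.qr(a)  # reduced QR: q has orthonormal columns
    # Flip column signs so diag(r) is positive; without this correction the
    # distribution of q is not uniform over orthogonal matrices.
    q = q * jnp.sign(jnp.diag(r))
    if n_rows < n_cols:
        q = q.T  # transpose so the wide result has orthonormal rows
    return scale * q

w = orthogonal_2d(jax.random.key(42), 2, 3)
print(w.shape)  # (2, 3)
print(jnp.allclose(w @ w.T, jnp.eye(2), atol=1e-5))
```

The sign correction matters because QR routines fix the signs of R's diagonal by convention rather than at random, which would otherwise bias the sampled matrices.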