jax.nn.initializers.variance_scaling

jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jax.numpy.float32)
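The following is a minimal usage sketch. It assumes the documented JAX behavior: the call returns an initializer `init(key, shape, dtype)`. `mode` is one of `"fan_in"`, `"fan_out"` or `"fan_avg"`, and `distribution` is one of `"truncated_normal"`, `"normal"` or `"uniform"`. The fans are read from `shape[in_axis]` and `shape[out_axis]`, and each is multiplied by the product of the remaining (receptive-field) dimensions. The variance is `scale / fan`. The example shapes and the He-style `scale=2.0` choice are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

key_dense, key_conv = jax.random.split(jax.random.PRNGKey(0))

# He-style uniform initializer: variance = 2.0 / fan_in, uniform samples.
he_uniform_init = variance_scaling(2.0, "fan_in", "uniform")

# Dense kernel laid out as (fan_in, fan_out), matching in_axis=-2, out_axis=-1.
w = he_uniform_init(key_dense, (256, 128), jnp.float32)

# For "uniform", samples lie in [-limit, limit) with limit = sqrt(3 * variance).
limit = jnp.sqrt(3.0 * 2.0 / 256)
print(w.shape, w.dtype, bool(jnp.all(jnp.abs(w) <= limit)))

# Conv kernel in HWIO layout (3, 3, 16, 32): in_axis=-2 selects 16 input channels,
# and the 3x3 receptive field multiplies it, so fan_in = 16 * 9 = 144.
normal_init = variance_scaling(1.0, "fan_in", "normal")
k = normal_init(key_conv, (3, 3, 16, 32), jnp.float32)

# The sample standard deviation should be close to sqrt(1.0 / 144) ~= 0.083.
print(float(jnp.std(k)))
```

Only the variance computation depends on `in_axis` and `out_axis`. Kernels stored in another layout can pass those indices explicitly instead of transposing the array.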