jax.nn.initializers.lecun_uniform

jax.nn.initializers.lecun_uniform(in_axis=-2, out_axis=-1, dtype=jax.numpy.float32)