jax.nn.initializers.ones

jax.nn.initializers.ones(key, shape, dtype=jax.numpy.float32)