jax.nn.initializers.constant

jax.nn.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns arrays full of a constant value.

Parameters:
    value: the constant value with which to fill the returned arrays.
    dtype: optional; the initializer's default dtype.

Return type: Initializer

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
       [-7., -7., -7.]], dtype=float32)
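
As a further illustration, the sketch below uses the initializer to build a bias vector. The names `out_features`, `bias_init`, `bias`, and `other_bias` are hypothetical and chosen only for this example. It also checks that two different PRNG keys give the same array, which is what one would expect from a constant fill: the key is part of the common initializer call signature but should not affect the result here.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer width, chosen only for illustration.
out_features = 3

# Build an initializer that fills arrays with 0.1.
bias_init = jax.nn.initializers.constant(0.1)

# Initializers are called as init(key, shape, dtype).
bias = bias_init(jax.random.PRNGKey(0), (out_features,), jnp.float32)

# A different key should yield an identical array for a constant fill.
other_bias = bias_init(jax.random.PRNGKey(123), (out_features,), jnp.float32)
same = bool(jnp.array_equal(bias, other_bias))

print(bias.shape, bias.dtype, same)
```

Passing the dtype explicitly at call time, as above, avoids relying on the default, which may resolve differently depending on whether 64-bit mode is enabled.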