jax.numpy.full

jax.numpy.full(shape, fill_value, dtype=None, *, device=None)

Return a new array of given shape and type, filled with fill_value.

LAX-backend implementation of numpy.full().

Original docstring below.

Parameters:
  • shape (int or sequence of ints) – Shape of the new array, e.g., (2, 3) or 2.

  • fill_value (scalar or array_like) – Fill value.

  • dtype (data-type, optional) – The desired data-type for the array. The default, None, means np.array(fill_value).dtype.

  • device (xc.Device | Sharding | None) – Optional device or sharding on which to place the new array.
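
The parameters above can be illustrated with a minimal sketch. The variable names are only for illustration, and the default dtypes you see depend on your JAX configuration (for example, whether 64-bit mode is enabled), so the comments avoid stating exact default dtypes.

```python
import jax.numpy as jnp

# An int shape gives a 1-D array; a tuple gives an N-D array.
a = jnp.full(3, 7)        # shape (3,)
b = jnp.full((2, 3), 7)   # shape (2, 3)

# With dtype=None, the dtype is inferred from fill_value
# (a floating-point type here, since 1.5 is a Python float).
c = jnp.full((2, 3), 1.5)

# An explicit dtype overrides the inferred one.
d = jnp.full((2, 3), 7, dtype=jnp.float32)

print(a.shape, b.shape, c.dtype, d.dtype)
```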

Returns:

out – Array of fill_value with the given shape and dtype.

Return type:

ndarray
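
As a sketch of what the returned value looks like, the example below passes an array-like fill_value instead of a scalar. It assumes the non-scalar fill value is broadcast to the requested shape, as in NumPy; the name `out` mirrors the Returns entry above and is otherwise arbitrary.

```python
import jax
import jax.numpy as jnp

# A non-scalar fill_value is broadcast against the requested shape:
# the length-3 row is repeated along the leading axis.
out = jnp.full((2, 3), jnp.array([1, 2, 3]))

print(out.shape)                    # (2, 3)
print(isinstance(out, jax.Array))   # the result is a JAX array
print(out)
```

The fill value must be broadcast-compatible with shape; an incompatible pair (for example, a length-2 row with shape (2, 3)) is expected to raise an error rather than silently truncate.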