jax.numpy.full_like

jax.numpy.full_like(a, fill_value, dtype=None, shape=None, *, device=None)

Create an array filled with a specified value, with the same shape and dtype as an input array.

JAX implementation of numpy.full_like().
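To illustrate what "same shape and dtype" means in practice, the sketch below compares jnp.full_like with an explicit jnp.full call built from the input's shape and dtype, and with NumPy's numpy.full_like on the same data. It assumes only that jax and numpy are installed; the variable names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

# full_like reads the shape (2, 3) and dtype float32 from x.
a = jnp.full_like(x, 7)

# Spelling out the same shape and dtype with jnp.full gives the same array.
b = jnp.full(x.shape, 7, dtype=x.dtype)

# NumPy's full_like on the equivalent host array produces matching values and dtype.
c = np.full_like(np.asarray(x), 7)
```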

Parameters:
  • a (ArrayLike | DuckTypedArray) – Array-like object with shape and dtype attributes.

  • fill_value (ArrayLike) – Scalar or array with which to fill the created array. An array value must be broadcast-compatible with the output shape.

  • dtype (DTypeLike | None) – Optionally override the dtype of the created array. If None, the dtype of a is used.

  • shape (Any | None) – Optionally override the shape of the created array. If None, the shape of a is used.

  • device (xc.Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed.
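Because a only needs shape and dtype attributes, a data-free description of an array should be enough. The sketch below assumes jax.ShapeDtypeStruct serves as such a duck-typed input (with a scalar fill_value), and also shows dtype and shape overriding the values that would otherwise come from a. The variable names are illustrative.

```python
import jax
import jax.numpy as jnp

# ShapeDtypeStruct holds only .shape and .dtype, no array data.
spec = jax.ShapeDtypeStruct((2, 3), jnp.float32)
ones = jnp.full_like(spec, 1.0)               # shape (2, 3), dtype float32

x = jnp.arange(4.0)                           # shape (4,)

# dtype= replaces the dtype taken from x; the shape still comes from x.
ints = jnp.full_like(x, 7, dtype=jnp.int32)   # shape (4,), dtype int32

# dtype= and shape= can be combined to override both.
grid = jnp.full_like(x, 7, dtype=jnp.int32, shape=(3, 2))
```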

Returns:

An array of the requested shape and dtype filled with fill_value, committed to device if one is given.

Return type:

Array
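As a rough sketch of the device parameter, the example below commits the result to the first device reported by jax.devices(), which is whatever backend happens to be available (CPU, GPU, or TPU), and checks placement with jax.Array.devices(). It assumes a JAX version whose full_like accepts device=, as in the signature above.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
dev = jax.devices()[0]   # first available device on this machine

# The result is committed to `dev`; without device=, JAX's default placement applies.
y = jnp.full_like(x, 3.0, device=dev)
```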

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(4.0)
>>> jnp.full_like(x, 2)
Array([2., 2., 2., 2.], dtype=float32)
>>> jnp.full_like(x, 0, shape=(2, 3))
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

fill_value may also be an array that is broadcast to the specified shape:

>>> x = jnp.arange(6).reshape(2, 3)
>>> jnp.full_like(x, fill_value=jnp.array([[1], [2]]))
Array([[1, 1, 1],
       [2, 2, 2]], dtype=int32)
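Extending the broadcasting example, the sketch below shows an array fill_value broadcast along rows, combined with a shape override, and the case where fill_value cannot be broadcast to the output shape. It assumes that incompatible shapes raise ValueError, as jnp.broadcast_to does; the variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.zeros((4, 3), dtype=jnp.float32)

# A length-3 row is repeated across all 4 rows; the int values are cast to x's float32.
rows = jnp.full_like(x, jnp.array([1, 2, 3]))

# shape= also applies to an array fill_value: a (2, 1) column fills a (2, 5) result.
cols = jnp.full_like(x, jnp.array([[10], [20]]), shape=(2, 5))

# A length-2 vector cannot be broadcast to shape (4, 3).
raised = False
try:
    jnp.full_like(x, jnp.array([1, 2]))
except ValueError:
    raised = True
```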