jax.device_put_replicated(x, devices)

Transfer array(s) to each specified device and form ShardedDeviceArray(s).

Parameters

  • x (Any) – An array, scalar, or (nested) standard Python container thereof representing the array to be replicated to form the output.

  • devices (Sequence[Device]) – A sequence of Device instances representing the devices to which x will be transferred.

This function is always asynchronous, i.e. it returns immediately, without waiting for the transfers to complete.
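The asynchronous behavior can be made concrete with a small sketch. It assumes the returned array exposes a block_until_ready() method, as JAX arrays do; the array size is arbitrary and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = jnp.ones((256, 256))

# The call returns right away; the copies may still be in flight.
y = jax.device_put_replicated(x, devices)

# Wait explicitly when the data must be resident before continuing,
# for example before starting a timer in a benchmark.
y = y.block_until_ready()
```

Explicit blocking of this kind is typically only needed for timing or debugging; later JAX computations that consume y wait for it automatically.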


Returns

A ShardedDeviceArray or (nested) Python container thereof representing the value of x broadcast along a new leading axis of size len(devices), with each slice along that new leading axis backed by memory on the device specified by the corresponding entry in devices.
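As a minimal sketch of the return value's layout, the snippet below checks that the result has shape (len(devices),) + x.shape and that every slice along the new leading axis equals x. The input shape is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = jnp.arange(6.0).reshape(2, 3)

y = jax.device_put_replicated(x, devices)

# A new leading axis of size len(devices) is added in front of x's shape.
assert y.shape == (len(devices),) + x.shape

# Each slice along the leading axis holds a copy of x.
for i in range(len(devices)):
    assert bool(jnp.array_equal(y[i], x))
```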


Examples

Passing an array:

>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = jax.numpy.array([1., 2., 3.])
>>> y = jax.device_put_replicated(x, devices)
>>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
True
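Passing a nested container: since x may be a (nested) standard Python container, each array inside it is replicated on its own and the container structure is preserved. The following is a sketch only; the dictionary keys "w" and "b" are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()

# Hypothetical parameter container; any nesting of dicts, lists, and tuples works.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

replicated = jax.device_put_replicated(params, devices)

# The result has the same keys, and each leaf gains a leading device axis.
assert set(replicated) == {"w", "b"}
assert replicated["w"].shape == (len(devices), 2, 2)
assert replicated["b"].shape == (len(devices), 2)
```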

See also

  • device_put

  • device_put_sharded