jax.device_put_replicated


jax.device_put_replicated(x, devices)

Transfer array(s) to each specified device and form Array(s).

Parameters:
  • x (Any) – an array, scalar, or (nested) standard Python container thereof representing the array to be replicated to form the output.

  • devices (Sequence[Device]) – A sequence of Device instances representing the devices to which x will be transferred.
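Because x may be a nested container, each leaf (array or scalar) is replicated independently. The sketch below assumes the target devices are simply jax.local_devices(). The dictionary keys "w" and "b" are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Assumption: replicate onto every device local to this process.
devices = jax.local_devices()
n = len(devices)

# A nested container: each leaf (an array and a Python scalar) is replicated on its own.
params = {"w": jnp.ones((2, 3)), "b": 0.5}
replicated = jax.device_put_replicated(params, devices)

# The container structure is preserved, and every leaf gains a new
# leading axis of size len(devices).
assert set(replicated.keys()) == {"w", "b"}
assert replicated["w"].shape == (n, 2, 3)
assert replicated["b"].shape == (n,)
```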

This function is always asynchronous, i.e., it returns immediately without waiting for the transfers to complete.
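Since the call does not wait for the transfers, code that needs the data to be resident on every device (for example, when timing the transfer) can block explicitly. A minimal sketch, assuming all local devices as targets and an arbitrarily chosen array size:

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = jnp.arange(1_000_000, dtype=jnp.float32)

# Returns right away; the copies to each device proceed in the background.
y = jax.device_put_replicated(x, devices)

# Wait until every per-device copy is ready before relying on it.
y.block_until_ready()
assert y.shape == (len(devices), 1_000_000)
```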

Returns:

An Array or (nested) Python container thereof representing the value of x broadcast along a new leading axis of size len(devices), with each slice along that new leading axis backed by memory on the device specified by the corresponding entry in devices.
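A sketch of how one might check the returned layout, assuming all local devices as targets: the leading axis has size len(devices), the addressable shards sit exactly on the requested devices, and every slice along the leading axis equals x.

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()
x = jnp.array([1.0, 2.0, 3.0])
y = jax.device_put_replicated(x, devices)

# New leading axis of size len(devices).
assert y.shape == (len(devices), 3)

# The shards of y live on exactly the requested devices.
shard_devices = {shard.device for shard in y.addressable_shards}
assert shard_devices == set(devices)

# Every slice along the leading axis holds a copy of x.
np.testing.assert_allclose(np.asarray(y), np.broadcast_to(np.asarray(x), y.shape))
```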

Examples

Passing an array:

>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = jax.numpy.array([1., 2., 3.])
>>> y = jax.device_put_replicated(x, devices)
>>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
True

See also

  • device_put

  • device_put_sharded