jax.device_put_replicated

jax.device_put_replicated(x, devices)

Transfer array(s) to each specified device and form Array(s).

Parameters:
  • x (Any) – An array, scalar, or (nested) standard Python container thereof representing the array to be replicated to form the output.

  • devices (Sequence[Device]) – A sequence of Device instances representing the devices to which x will be transferred.
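Because x may be a nested container, a whole pytree can be replicated in one call. The sketch below uses a hypothetical `params` dict holding an array and a Python scalar, and checks that every leaf gains a leading axis of size len(devices) while the container structure is preserved.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()

# Hypothetical parameter pytree: a dict holding an array and a Python scalar.
params = {"w": jnp.ones((2, 3)), "scale": 0.5}

# Each leaf is replicated independently; the dict structure is kept.
replicated = jax.device_put_replicated(params, devices)

# Every leaf now has a new leading axis of size len(devices).
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, replicated)
print(shapes)
```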

This function is always asynchronous, i.e. it returns immediately without waiting for the transfers to complete.
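Since the call returns before the copies are written, code that needs the data to be resident (for example, when timing the transfer) can wait explicitly. A minimal sketch using the `block_until_ready` method of `jax.Array`, with a hypothetical host array of one million float32 values:

```python
import time

import numpy as np
import jax

devices = jax.local_devices()
x = np.arange(1_000_000, dtype=np.float32)  # host data to replicate

start = time.perf_counter()
# Returns as soon as the transfers are enqueued.
y = jax.device_put_replicated(x, devices)
# Block until every per-device copy has actually been written.
y.block_until_ready()
elapsed = time.perf_counter() - start
print(f"replicated to {len(devices)} device(s) in {elapsed:.4f} s")
```

Without the `block_until_ready()` call, `elapsed` would mostly measure dispatch time rather than the transfer itself.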

Returns:

An Array, or a (nested) Python container of Arrays, representing the value of x broadcast along a new leading axis of size len(devices). Each slice along that leading axis is backed by memory on the device given by the corresponding entry in devices.
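The placement described above can be inspected on the result. This sketch uses `Array.devices()` and `Array.addressable_shards` to confirm that the output spans exactly the requested devices and that each per-device slice holds a copy of x. Each shard is reshaped before comparison because, depending on the JAX version, a per-device buffer may or may not keep a size-1 leading axis.

```python
import numpy as np
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = jnp.array([1., 2., 3.])
y = jax.device_put_replicated(x, devices)

# New leading axis of size len(devices), followed by x's own shape.
print(y.shape)

# The result is backed by exactly the requested devices.
print(y.devices() == set(devices))

# One addressable shard per device, each holding a full copy of x.
for shard in y.addressable_shards:
    # Reshape in case the per-device buffer keeps a size-1 leading axis.
    local = np.asarray(shard.data).reshape(x.shape)
    print(shard.device, local)
```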

Examples

Passing an array:

>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = jax.numpy.array([1., 2., 3.])
>>> y = jax.device_put_replicated(x, devices)
>>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
True
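A typical reason to replicate is to hand the same value to every device in a `jax.pmap` computation, whose default `in_axes=0` maps over the new leading axis. A sketch, assuming `jax.pmap` is available in the installed JAX version; the doubling function is only illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = jnp.array([1., 2., 3.])
y = jax.device_put_replicated(x, devices)

# pmap maps over the leading axis, so each device works on its own copy of x.
doubled = jax.pmap(lambda a: a * 2.0)(y)
print(doubled.shape)
```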

See also

  • device_put

  • device_put_sharded
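To relate this function to `device_put_sharded`, which stacks a list of one shard per device along a new leading axis, the sketch below checks that passing one copy of x per device to it yields the same values as replicating x:

```python
import numpy as np
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = jnp.array([1., 2., 3.])

replicated = jax.device_put_replicated(x, devices)
# device_put_sharded takes one shard per device and stacks them.
sharded = jax.device_put_sharded([x for _ in devices], devices)

print(np.allclose(replicated, sharded))
```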