- jax.device_put_replicated(x, devices)
Transfer array(s) to each specified device and form ShardedDeviceArray(s).
This function is always asynchronous, i.e. it returns immediately.
Returns: A ShardedDeviceArray or (nested) Python container thereof representing the value of x broadcasted along a new leading axis of size len(devices), with each slice along that new leading axis backed by memory on the device specified by the corresponding entry in devices.
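To make the return-value description concrete for a nested container, here is a minimal sketch that computes only the value each output leaf should hold, assuming device placement can be ignored (everything below lives on the default device). The helper replicated_value and the example params dict are hypothetical illustrations, not part of JAX. The sketch relies only on jax.tree_util.tree_map and jax.numpy.stack.

```python
import jax
import jax.numpy as jnp


def replicated_value(tree, num_devices):
    """Hypothetical helper: the value device_put_replicated documents returning.

    Each leaf is stacked num_devices times along a new leading axis.
    Device placement is NOT modelled; all results stay on the default device.
    """
    def stack_leaf(leaf):
        leaf = jnp.asarray(leaf)  # scalars and arrays alike become jax arrays
        return jnp.stack([leaf] * num_devices)

    # tree_map preserves the container structure and transforms each leaf.
    return jax.tree_util.tree_map(stack_leaf, tree)


devices = jax.local_devices()
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3), "step": 0}
expected = replicated_value(params, len(devices))

# The leading axis has size len(devices); trailing dims match each original leaf.
print(jax.tree_util.tree_map(lambda a: a.shape, expected))
```

Comparing each leaf of jax.device_put_replicated(params, devices) with the matching leaf of this sketch via np.allclose follows the same pattern as the array example below, applied leaf by leaf.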
Passing an array:
>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = jax.numpy.array([1., 2., 3.])
>>> y = jax.device_put_replicated(x, devices)
>>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
True