jax.device_put_replicated
- jax.device_put_replicated(x, devices)
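Transfer array(s) to each specified device and form Array(s). This function is always asynchronous, i.e., it returns immediately.
- Parameters
  - x: an array, scalar, or (nested) standard Python container thereof, representing the value to be replicated to form the output.
  - devices: a sequence of Device instances representing the devices to which x will be transferred.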
- Returns
An Array or (nested) Python container thereof representing the value of x broadcast along a new leading axis of size len(devices), with each slice along that new leading axis backed by memory on the device specified by the corresponding entry in devices.
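The following is a minimal sketch that checks the return properties described above: the result gains a leading axis of length `len(devices)`, and its data lives on exactly the requested devices. Because the call is asynchronous, the sketch uses `block_until_ready()` before inspecting the result. The input values are arbitrary. The sketch also runs on a machine with a single CPU device, in which case the leading axis has size 1.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Replicate onto every device visible to this process (a single CPU device is fine).
devices = jax.local_devices()
x = jnp.arange(6.0).reshape(2, 3)

# Asynchronous: the call returns before the copies are guaranteed to be complete.
y = jax.device_put_replicated(x, devices)

# Block until every per-device copy has been written.
y = y.block_until_ready()

# A new leading axis of size len(devices) is prepended to x's shape.
print(y.shape)  # (N, 2, 3), where N == len(devices)

# The result's data lives on exactly the requested devices.
print(y.devices() == set(devices))  # True

# Every slice along the new leading axis equals x.
print(all(np.array_equal(y[i], x) for i in range(len(devices))))  # True
```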
Examples
Passing an array:
>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = jax.numpy.array([1., 2., 3.])
>>> y = jax.device_put_replicated(x, devices)
>>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
True
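Because x may be a (nested) Python container, a single call can replicate a whole pytree. The sketch below uses a hypothetical `params` dictionary, chosen only for illustration. It shows that the container structure is preserved and that every leaf gains the leading device axis.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()

# Hypothetical parameter pytree, used only for illustration.
params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}

# Each leaf is replicated independently; the dict structure is kept.
replicated = jax.device_put_replicated(params, devices)

for name in ("w", "b"):
    # Every leaf's shape gains a leading axis of size len(devices).
    print(name, replicated[name].shape)  # w -> (N, 4, 2), b -> (N, 2), N == len(devices)
```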
See also
device_put
device_put_sharded