jax.device_put_sharded
- jax.device_put_sharded(shards, devices)
Transfer array shards to specified devices and form Array(s).
- Parameters
shards (Sequence[Any]) – A sequence of arrays, scalars, or (nested) standard Python containers thereof representing the shards to be stacked together to form the output. The length of shards must equal the length of devices.
devices (Sequence[Device]) – A sequence of Device instances representing the devices to which corresponding shards in shards will be transferred.
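The length requirement on shards and devices can be checked directly. This is a minimal sketch. It assumes the devices come from jax.local_devices(), as in the examples on this page, and that a length mismatch surfaces as a ValueError, which is what recent JAX releases raise as far as we know. The name too_many is purely illustrative.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()

# One shard per device, all with the same shape (3,): the lengths match.
shards = [jnp.full((3,), float(i)) for i in range(len(devices))]
y = jax.device_put_sharded(shards, devices)
print(y.shape)  # (len(devices), 3)

# One shard too many: the lengths no longer match, so the call raises.
# (Assumption: the error type is ValueError.)
too_many = shards + [jnp.zeros(3)]
try:
    jax.device_put_sharded(too_many, devices)
    rejected = False
except ValueError as err:
    rejected = True
    print("rejected:", err)
```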
This function is always asynchronous, i.e. it returns immediately without waiting for the transfers to complete.
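Because the call is asynchronous, code that times transfers or must know the data has actually landed needs an explicit wait. This is a minimal sketch, assuming a single-array output (for which jax.Array.block_until_ready() is available) and devices taken from jax.local_devices():

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
shards = [jnp.arange(4.0) + i for i in range(len(devices))]

# Returns right away; the copies to each device may still be in flight.
y = jax.device_put_sharded(shards, devices)

# Wait until every shard is resident on its device, e.g. before timing code.
# block_until_ready() returns the array itself, so reassigning is harmless.
y = y.block_until_ready()
print(y.shape)  # (len(devices), 4)
```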
- Returns
An Array or (nested) Python container thereof representing the elements of shards stacked together along a new leading axis of length len(devices), with each shard backed by physical device memory specified by the corresponding entry in devices.
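The placement described under Returns can be inspected through the result's addressable_shards. This sketch fills shard i with the value i so that its contents identify it. It reads only one element of each per-device buffer, so it does not depend on whether that buffer keeps the stacked axis as size 1 or drops it, a detail that may vary with the JAX version. The name placement is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()
# Shard i is a (2, 3) block filled with the value i.
shards = [jnp.full((2, 3), i) for i in range(len(devices))]
y = jax.device_put_sharded(shards, devices)

# Stacking adds a leading axis of length len(devices).
print(y.shape)  # (len(devices), 2, 3)

# Map each input index to the device holding that shard's data.
placement = {}
for shard in y.addressable_shards:
    # ravel() avoids assuming a particular per-device buffer shape.
    index = int(np.asarray(shard.data).ravel()[0])
    placement[index] = shard.device

# Shard i should live on devices[i].
print(placement == dict(enumerate(devices)))  # True
```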
Examples
Passing a list of arrays for shards results in a sharded array containing a stacked version of the inputs:
>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = [jax.numpy.ones(5) for device in devices]
>>> y = jax.device_put_sharded(x, devices)
>>> np.allclose(y, jax.numpy.stack(x))
True
Passing a list of nested container objects with arrays at the leaves for shards corresponds to stacking the shards at each leaf. This requires all entries in the list to have the same tree structure:
>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
>>> y = jax.device_put_sharded(x, devices)
>>> type(y)
<class 'tuple'>
>>> y0 = jax.device_put_sharded([a for a, b in x], devices)
>>> y1 = jax.device_put_sharded([b for a, b in x], devices)
>>> np.allclose(y[0], y0)
True
>>> np.allclose(y[1], y1)
True
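The same tree-structure rule applies to any pytree, not only tuples. This is a minimal sketch with dict-valued shards. The keys step and weights are hypothetical. Before calling device_put_sharded, it checks that every entry shares one structure using jax.tree_util.tree_structure:

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
n = len(devices)

# Every entry is a dict with the same keys, i.e. the same pytree structure.
shards = [{"step": i, "weights": jnp.full((2,), float(i))} for i in range(n)]
first = jax.tree_util.tree_structure(shards[0])
assert all(jax.tree_util.tree_structure(s) == first for s in shards)

y = jax.device_put_sharded(shards, devices)

# The result keeps the dict structure; each leaf is stacked across shards.
print(sorted(y))           # ['step', 'weights']
print(y["step"].shape)     # (n,)
print(y["weights"].shape)  # (n, 2)
```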
See also
device_put
device_put_replicated
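For comparison with device_put_replicated: when the goal is the same value on every device, replicating x yields the same stacked values as passing len(devices) copies of x to device_put_sharded. In that case device_put_replicated just saves building the list. This is a sketch that assumes devices come from jax.local_devices():

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()
x = jnp.arange(3.0)

# device_put_replicated copies one value to every device...
replicated = jax.device_put_replicated(x, devices)
# ...which matches stacking len(devices) copies of x with device_put_sharded.
sharded = jax.device_put_sharded([x] * len(devices), devices)

print(replicated.shape == sharded.shape)                            # True
print(np.array_equal(np.asarray(replicated), np.asarray(sharded)))  # True
```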