jax.device_put_sharded(shards, devices)

Transfer array shards to specified devices and form ShardedDeviceArray(s).

Parameters

  • shards (Sequence[Any]) – A sequence of arrays, scalars, or (nested) standard Python containers thereof representing the shards to be stacked together to form the output. The length of shards must equal the length of devices, all shards must have the same tree structure, and corresponding leaves must have the same shape and dtype.

  • devices (Sequence[Device]) – A sequence of Device instances; the i-th entry of shards is transferred to the i-th entry of devices.
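
The constraints on these two parameters can be checked directly. The sketch below, assuming only the devices returned by jax.local_devices(), stacks one shard per device and then shows that a length mismatch between shards and devices is rejected. The ValueError type is what current JAX raises; treat the exact exception class as an assumption on other versions.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
n = len(devices)

# One shard per device: shard i is a length-3 vector filled with the value i.
shards = [jnp.full((3,), i, dtype=jnp.float32) for i in range(n)]
stacked = jax.device_put_sharded(shards, devices)
print(stacked.shape)  # (n, 3): shards are stacked along a new leading axis

# Supplying one more shard than there are devices violates the length rule.
# Assumption: current JAX reports this as a ValueError.
try:
    jax.device_put_sharded(shards + [shards[0]], devices)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
print(mismatch_rejected)
```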

This function is always asynchronous, i.e. it returns immediately without waiting for the transfers to complete.
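
Because the call returns before the copies finish, code that needs the data to be resident on the devices (for example, when timing the transfer) has to wait explicitly. A minimal sketch, assuming the result exposes block_until_ready() as jax.Array does in current JAX releases:

```python
import time

import jax
import jax.numpy as jnp

devices = jax.local_devices()
shards = [jnp.ones((1024, 1024)) for _ in devices]

start = time.perf_counter()
# Returns as soon as the transfers are enqueued, not when they complete.
y = jax.device_put_sharded(shards, devices)
# Block the host until every per-device copy has landed.
y = y.block_until_ready()
elapsed = time.perf_counter() - start
print(f"transfer of {y.shape} took {elapsed:.4f} s")
```

Without the block_until_ready() call, the measured time would mostly reflect dispatch overhead rather than the transfers themselves.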

Returns

A ShardedDeviceArray, or a (nested) Python container thereof, representing the elements of shards stacked together along a new leading axis, with each shard backed by the physical device memory of the corresponding entry in devices.
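
To confirm where each slice of the result lives, the sketch below inspects the output's sharding. It assumes the jax.Array interface of current JAX releases (sharding.device_set and addressable_shards), which superseded ShardedDeviceArray; older versions expose different attributes.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
shards = [jnp.arange(4.0) + 10 * i for i in range(len(devices))]
y = jax.device_put_sharded(shards, devices)

# The result is backed by exactly the requested devices.
print(y.sharding.device_set == set(devices))

# Each addressable shard reports the device holding it and its index into y.
for shard in y.addressable_shards:
    print(shard.device, shard.index)
```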

Examples

Passing a list of arrays for shards results in a sharded array containing a stacked version of the inputs:

>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = [jax.numpy.ones(5) for device in devices]
>>> y = jax.device_put_sharded(x, devices)
>>> np.allclose(y, jax.numpy.stack(x))
True
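
A common consumer of such a stacked array is jax.pmap, which maps a function over the leading axis with one slice per device. The sketch below rebuilds the same input as above and applies a doubling function; the lambda is purely illustrative.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = [jnp.ones(5) for _ in devices]
y = jax.device_put_sharded(x, devices)

# pmap runs the function once per device; each call sees one 5-element shard.
doubled = jax.pmap(lambda v: v * 2.0)(y)
print(doubled.shape)  # (len(devices), 5)
```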

Passing a list of nested container objects with arrays at the leaves for shards corresponds to stacking the shards at each leaf. This requires all entries in the list to have the same tree structure:

>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
>>> y = jax.device_put_sharded(x, devices)
>>> type(y)
<class 'tuple'>
>>> y0 = jax.device_put_sharded([a for a, b in x], devices)
>>> y1 = jax.device_put_sharded([b for a, b in x], devices)
>>> np.allclose(y[0], y0)
True
>>> np.allclose(y[1], y1)
True
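
The same per-leaf stacking applies to any pytree, such as a dictionary of per-device parameters. In the sketch below, the params_for helper and its "w" and "b" keys are hypothetical names chosen for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
n = len(devices)


def params_for(i):
    # Hypothetical per-device parameters: a weight matrix and a bias vector.
    return {"w": jnp.full((2, 3), float(i)), "b": jnp.zeros(3) + i}


# Every entry must share the same tree structure (same keys, same nesting).
tree_shards = [params_for(i) for i in range(n)]
stacked = jax.device_put_sharded(tree_shards, devices)

# The dictionary structure is preserved; each leaf gains a leading device axis.
print(sorted(stacked))     # ['b', 'w']
print(stacked["w"].shape)  # (n, 2, 3) with n = len(devices)
print(stacked["b"].shape)  # (n, 3)
```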

See also

  • device_put

  • device_put_replicated
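
For comparison with the related functions listed above, device_put_replicated places one value on every device. The sketch below, assuming both functions are available in the installed JAX version, checks that for a single array this produces the same stacked result as passing that array once per device to device_put_sharded.

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()
x = jnp.arange(3.0)

# Replicate a single value across all devices.
replicated = jax.device_put_replicated(x, devices)
# The same layout, built by listing the identical shard once per device.
sharded = jax.device_put_sharded([x] * len(devices), devices)

print(replicated.shape == sharded.shape)  # True
print(np.allclose(replicated, sharded))   # True
```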