- jax.device_put_sharded(shards, devices)
Transfer array shards to specified devices and form Array(s).
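Parameters:
shards (Sequence[Any]) – A sequence of arrays, scalars, or (nested) standard Python containers thereof representing the shards to be stacked together to form the output. The length of shards must equal the length of devices.
devices (Sequence[Device]) – A sequence of Device instances representing the devices to which the corresponding entries of shards will be transferred.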
This function is always asynchronous: it returns immediately, without waiting for the transfers to complete.
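Because the number of shards has to match the number of devices, a mismatch is rejected before any data is transferred. The sketch below assumes the behavior of recent JAX releases, where this length check raises `ValueError`; the shard shape `(3,)` is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()

# One shard per device: the lengths match, so the call succeeds.
ok = jax.device_put_sharded([jnp.zeros(3) for _ in devices], devices)

# One shard too many: assumed to be rejected with a ValueError by the length check.
mismatch_error = None
try:
    jax.device_put_sharded([jnp.zeros(3) for _ in range(len(devices) + 1)], devices)
except ValueError as err:
    mismatch_error = err

print(ok.shape)
print(mismatch_error)
```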
Returns:
An Array or (nested) Python container thereof representing the elements of shards stacked together, with each shard backed by the physical memory of the corresponding entry in devices.
Passing a list of arrays for shards results in a sharded array containing a stacked version of the inputs:
>>> import jax
>>> import numpy as np
>>> devices = jax.local_devices()
>>> x = [jax.numpy.ones(5) for device in devices]
>>> y = jax.device_put_sharded(x, devices)
>>> np.allclose(y, jax.numpy.stack(x))
True
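To check which device holds which shard, one can inspect the result's `addressable_shards`. This is a sketch using the standard `jax.Array.addressable_shards` attribute and the `device` and `data` fields of each `jax.Shard`; the shard values are hypothetical test data. Since the transfer is asynchronous, the sketch explicitly calls `block_until_ready()` before reading the shards.

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()

# Hypothetical test data: shard i is a length-2 vector filled with the value i.
shards = [jnp.full((2,), float(i)) for i in range(len(devices))]

y = jax.device_put_sharded(shards, devices)

# The call returned immediately; wait here until every transfer has finished.
y.block_until_ready()

# Map each device to a host copy of the data it holds.
data_on = {s.device: np.asarray(s.data) for s in y.addressable_shards}

# Shard i should be stored on devices[i]. allclose broadcasts, so this also
# works if the per-device buffer keeps a leading axis of size 1.
placement_ok = all(
    np.allclose(data_on[d], shards[i]) for i, d in enumerate(devices)
)
print(y.shape, placement_ok)
```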
Passing a list of nested container objects with arrays at the leaves for shards corresponds to stacking the shards at each leaf. This requires all entries in the list to have the same tree structure:
>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
>>> y = jax.device_put_sharded(x, devices)
>>> type(y)
<class 'tuple'>
>>> y0 = jax.device_put_sharded([a for a, b in x], devices)
>>> y1 = jax.device_put_sharded([b for a, b in x], devices)
>>> np.allclose(y[0], y0)
True
>>> np.allclose(y[1], y1)
True
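The leaf-wise stacking can also be checked against a host-side reference built with `jax.tree_util.tree_map`, which pairs up the i-th leaf of every shard and stacks them. A sketch, assuming each shard is a dict with the hypothetical keys `"w"` (a vector) and `"b"` (a Python scalar):

```python
import jax
import jax.numpy as jnp
import numpy as np

devices = jax.local_devices()
n = len(devices)

# Every shard is a dict with the same keys, i.e. the same tree structure.
shards = [{"w": jnp.full((3,), float(i)), "b": float(i)} for i in range(n)]
y = jax.device_put_sharded(shards, devices)

# Reference result: stack corresponding leaves on the host, leaf by leaf.
expected = jax.tree_util.tree_map(lambda *leaves: np.stack(leaves), *shards)

# The output keeps the container structure, with a stacked array at each leaf.
same_structure = (
    jax.tree_util.tree_structure(y) == jax.tree_util.tree_structure(expected)
)
leaves_match = all(
    np.allclose(a, b)
    for a, b in zip(jax.tree_util.tree_leaves(y), jax.tree_util.tree_leaves(expected))
)
print(same_structure, leaves_match)
```

If one entry had a different structure (for example, a dict missing `"b"`), its leaves could not be paired up with the others, which is why all entries must share the same tree structure.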