jax.make_array_from_single_device_arrays
jax.make_array_from_single_device_arrays(shape, sharding, arrays)
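Returns a jax.Array from a sequence of jax.Arrays that each live on a single device. A jax.Array on a single device is analogous to a DeviceArray. Use this function when you have already placed the values on individual devices (for example with jax.device_put) and want to assemble them into one global Array. The smaller jax.Arrays should be addressable and belong to the current process.
- Parameters:
  - shape: the global shape of the output jax.Array.
  - sharding: a Sharding describing how the output jax.Array is laid out across devices.
  - arrays: a sequence of single-device jax.Arrays, one for each device of sharding that the current process can address, each holding the piece of the global value that sharding assigns to that device. In multi-process code, each process passes the arrays for its own devices.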
- Return type:
ArrayImpl
- Returns:
A jax.Array built from the sequence of jax.Arrays that each live on a single device.
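A common reason to use this function is that each shard can be produced separately, so the full global value never has to be materialized in one place. The sketch below assumes a 1-D mesh over all of jax.devices() (so it also runs on a single CPU device) and a hypothetical load_shard helper, standing in for something like reading part of a file, that builds only the rows an index selects. It then checks the assembled array against the global value the helper describes.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def load_shard(index, shape):
    # Hypothetical loader (not part of JAX): materializes only the rows
    # selected by index[0] of a conceptual global array whose element
    # (r, c) equals r * shape[1] + c. Assumes columns are not sharded,
    # which holds for P("x") below.
    start, stop, step = index[0].indices(shape[0])
    rows = np.arange(start, stop, step, dtype=np.float32)[:, None]
    cols = np.arange(shape[1], dtype=np.float32)[None, :]
    return rows * shape[1] + cols


# Use every available device; the leading axis is split evenly across them.
devices = np.array(jax.devices())
n = devices.size
shape = (2 * n, 3)
sharding = NamedSharding(Mesh(devices, ("x",)), P("x"))

# One single-device array per addressable device, built from its own slice.
arrays = [
    jax.device_put(load_shard(index, shape), d)
    for d, index in sharding.addressable_devices_indices_map(shape).items()
]
arr = jax.make_array_from_single_device_arrays(shape, sharding, arrays)

# In a single process every shard is addressable, so the global value can be
# compared with the array the loader describes.
expected = np.arange(shape[0] * shape[1], dtype=np.float32).reshape(shape)
print(np.array_equal(np.asarray(arr), expected))  # True
```

Only the per-device slices ever exist as NumPy arrays here; expected is built solely to verify the result.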
Example
>>> import math
>>> import jax
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> shape = (8, 8)
>>> # This example assumes 8 devices, arranged as a 2 x 4 mesh.
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
>>> inp_data = np.arange(math.prod(shape)).reshape(shape)
...
>>> arrays = [
...     jax.device_put(inp_data[index], d)
...     for d, index in sharding.addressable_devices_indices_map(shape).items()]
...
>>> arr = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
>>> arr.addressable_data(0).shape
(4, 2)
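The same call also works for a fully replicated sharding: every addressable device then holds a complete copy, so each entry of arrays is the whole global value. This sketch assumes a 1-D mesh over jax.devices() and uses an empty PartitionSpec() to request replication; it is an illustration, not a prescribed pattern.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
# An empty PartitionSpec replicates the array on every device in the mesh.
replicated = NamedSharding(mesh, P())

shape = (4, 4)
value = np.ones(shape, dtype=np.float32)

# Under a replicated sharding, every device's index selects the full array.
arrays = [
    jax.device_put(value[index], d)
    for d, index in replicated.addressable_devices_indices_map(shape).items()
]
arr = jax.make_array_from_single_device_arrays(shape, replicated, arrays)

print(arr.is_fully_replicated)  # True
print([s.data.shape for s in arr.addressable_shards])
```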