jax.make_array_from_single_device_arrays

jax.make_array_from_single_device_arrays(shape, sharding, arrays)

Returns a jax.Array assembled from a sequence of jax.Arrays, each of which is on a single device.

Use this function when you have already placed each piece of the value on a single device with jax.device_put and want to combine the pieces into a global jax.Array. The smaller jax.Arrays should be addressable and belong to the current process.

Parameters
  • shape (tuple[int, ...]) – Shape of the jax.Array.

  • sharding (Sharding) – A Sharding instance which describes how the jax.Array is laid out across devices.

  • arrays (Sequence[Array]) – Sequence of jax.Arrays, each of which is on a single device.

Return type

ArrayImpl

Returns

A jax.Array assembled from a sequence of jax.Arrays, each of which is on a single device.

Example

>>> import math
>>> import jax
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> global_shape = (8, 8)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
>>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape)
...
>>> arrays = [
...     jax.device_put(inp_data[index], d)
...     for d, index in sharding.addressable_devices_indices_map(global_shape).items()]
...
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
>>> arr.addressable_data(0).shape
(4, 2)
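
The relationship between shape, sharding, and arrays can be checked directly. The following is a minimal sketch, not part of the original example: it assumes a 1D mesh over however many devices are available (so it also runs on a single device), and the axis name 'x' and the shape (2 * n, 4) are illustrative choices. It verifies that every addressable shard has the shape given by sharding.shard_shape and holds the matching slice of the input.

```python
import math

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1D mesh over all available devices, so the sketch is not tied
# to the 8-device (2, 4) layout used in the example above.
devices = np.array(jax.devices())
mesh = Mesh(devices, ('x',))
sharding = NamedSharding(mesh, P('x'))

# Choose a leading dimension that divides evenly across the mesh.
global_shape = (2 * devices.size, 4)
inp_data = np.arange(math.prod(global_shape), dtype=np.float32).reshape(global_shape)

# One single-device array per addressable device, holding that device's slice.
arrays = [
    jax.device_put(inp_data[index], d)
    for d, index in sharding.addressable_devices_indices_map(global_shape).items()
]

arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)

# Each shard should have the per-device shape implied by the sharding and
# contain exactly the slice of inp_data that its index selects.
expected_shard_shape = sharding.shard_shape(global_shape)
for shard in arr.addressable_shards:
    assert shard.data.shape == expected_shard_shape
    np.testing.assert_array_equal(np.asarray(shard.data), inp_data[shard.index])
```

Only addressable shards are inspected, so the same checks should also apply per process in a multi-process run.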

In the multi-process case, if the input is process-local and data-parallel (i.e. each process receives a different part of the data), you can use make_array_from_single_device_arrays to create a global jax.Array:

>>> local_shape = (8, 2)
>>> global_shape = (jax.process_count() * local_shape[0], ) + local_shape[1:]
>>> local_array = np.arange(math.prod(local_shape)).reshape(local_shape)
>>> arrays = jax.device_put(
...     np.split(local_array, len(global_mesh.local_devices), axis=0),
...     global_mesh.local_devices)
>>> sharding = jax.sharding.NamedSharding(global_mesh, P(('x', 'y'), ))
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
>>> arr.addressable_data(0).shape
(1, 2)
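
To see where each process-local chunk ends up in the global array, you can inspect the addressable shards after construction. The sketch below is an illustration under stated assumptions rather than part of the original example: the 1D mesh, the axis name 'data', and the values of rows_per_device and the feature dimension are all hypothetical. It is written so that it also runs in a single process, where jax.process_count() is 1.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1D mesh over all devices; 'data' is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), ('data',))
sharding = NamedSharding(mesh, P('data'))
local_devices = mesh.local_devices

# Hypothetical per-process batch: 2 rows per local device, 3 features.
rows_per_device = 2
local_shape = (rows_per_device * len(local_devices), 3)
global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
local_array = np.arange(
    local_shape[0] * local_shape[1], dtype=np.float32).reshape(local_shape)

# Split the local batch into one chunk per local device and place each chunk.
chunks = np.split(local_array, len(local_devices), axis=0)
arrays = [jax.device_put(chunk, d) for chunk, d in zip(chunks, local_devices)]

arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)

# Show which rows of the global array each local device holds.
for shard in arr.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```

Each process passes only its own chunks, while global_shape describes the array across all processes.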