jax.experimental.global_device_array module#

API#

class jax.experimental.global_device_array.GlobalDeviceArray(global_shape, global_mesh, mesh_axes, device_buffers, _gda_fast_path_args=None, _enable_checks=True)[source]#

A logical array with data sharded across multiple devices and processes.

If you’re not already familiar with JAX’s multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. You can also read about pjit (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html) to learn about Mesh, PartitionSpec and how arrays can be partitioned or replicated.

A GlobalDeviceArray (GDA) can be thought of as a view into a single logical array sharded across processes. The logical array is the β€œglobal” array, and each process has a GlobalDeviceArray object referring to the same global array (similarly to how each process runs a multi-process pmap or pjit). Each process can access the shape, dtype, etc. of the global array via the GDA, pass the GDA into multi-process pjits, and get GDAs as pjit outputs (coming soon: xmap and pmap). However, each process can only directly access the shards of the global array data stored on its local devices.

GDAs can help manage the inputs and outputs of multi-process computations. A GDA keeps track of which shard of the global array belongs to which device, and provides callback-based APIs to materialize the correct shard of the data needed for each local device of each process.

A GDA consists of data shards, each stored on a different device. There are local shards and global shards. Local shards are those on the current process's local devices, and their data is visible to that process. Global shards are the shards across all devices (including local ones); a shard's data isn't available if it resides on a device that is non-local with respect to the current process. See the Shard class for the information stored in each shard.
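As a quick illustration, the two shard lists can be compared directly. A minimal sketch, assuming a GDA named gda constructed as in the example below:

# Every local shard has materialized data; a global shard's data is None
# whenever its device belongs to a different process.
assert all(shard.data is not None for shard in gda.local_shards)
remote_shards = [s for s in gda.global_shards if s.data is None]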

Note: to make pjit output GlobalDeviceArrays, set the environment variable JAX_PARALLEL_FUNCTIONS_OUTPUT_GDA=true or add the following to your code: jax.config.update('jax_parallel_functions_output_gda', True)

Parameters
  • global_shape (Tuple[int, ...]) – The global shape of the array.

  • global_mesh (Mesh) – The global mesh representing devices across multiple processes.

  • mesh_axes (PartitionSpec) –

    A sequence with length less than or equal to the rank of the global array (i.e. the length of the global shape). Each element can be:

    • An axis name of global_mesh, indicating that the corresponding global array axis is partitioned across the given device axis of global_mesh.

    • A tuple of axis names of global_mesh. This is like the above option except the global array axis is partitioned across the product of axes named in the tuple.

    • None indicating that the corresponding global array axis is not partitioned.

    For more information, please see https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#more-information-on-partitionspec; a short sketch of these options follows this parameter list.

  • device_buffers (Union[Any, Sequence[DeviceArray]]) – DeviceArrays that are on the local devices of global_mesh.

  • _gda_fast_path_args (Optional[_GdaFastPathArgs]) – Internal-only argument (note the leading underscore); not part of the public API.

  • _enable_checks (bool) – Internal-only flag controlling consistency checks; not part of the public API.
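To make the three kinds of mesh_axes entries concrete, here is a short sketch (P is PartitionSpec, as in the examples below), assuming a 4x2 mesh with axis names ('x', 'y') and a rank-2 global array:

P('x', 'y')    # dim 0 split 4-way across 'x', dim 1 split 2-way across 'y'
P(('x', 'y'))  # dim 0 split 8-way across the product of 'x' and 'y'; dim 1 replicated
P(None, 'y')   # dim 0 replicated, dim 1 split 2-way across 'y'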

shape#

Global shape of the array.

dtype#

Dtype of the global array.

ndim#

Number of array dimensions in the global shape.

size#

Number of elements in the global array.

local_shards#

List of Shard objects on the local devices of the current process. Data is materialized for all local shards.

global_shards#

List of all Shard objects of the global array. Data isn't available if a shard is on a non-local device with respect to the current process.

is_fully_replicated#

True if the full array value is present on all devices of the global mesh.

Example

>>> import jax
>>> import numpy as np
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> from jax.experimental.global_device_array import GlobalDeviceArray
...
>>> assert jax.device_count() == 8
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> # Logical mesh is (hosts, devices)
>>> assert global_mesh.shape == {'x': 4, 'y': 2}
>>> global_input_shape = (8, 2)
>>> mesh_axes = P('x', 'y')
...
>>> # Dummy example data; in practice we wouldn't necessarily materialize global data
>>> # in a single process.
>>> global_input_data = np.arange(
...   np.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def get_local_data_slice(index):
...  # index will be a tuple of slice objects, e.g. (slice(0, 2, None), slice(0, 1, None))
...  # This method will be called per-local device from the GDA constructor.
...  return global_input_data[index]
...
>>> gda = GlobalDeviceArray.from_callback(
...        global_input_shape, global_mesh, mesh_axes, get_local_data_slice)
>>> print(gda.shape)
(8, 2)
>>> print(gda.local_shards[0].data)  # Access the data on a single local device
[[0]
 [2]]
>>> print(gda.local_shards[0].data.shape)
(2, 1)
>>> # Numpy-style index into the global array that this data shard corresponds to
>>> print(gda.local_shards[0].index)
(slice(0, 2, None), slice(0, 1, None))

GDAs can also be passed as inputs to pjit, and pjit can return GDAs as outputs:

from jax.experimental.pjit import pjit

# Allow pjit to output GDAs
jax.config.update('jax_parallel_functions_output_gda', True)

f = pjit(lambda x: x @ x.T,
         in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
with global_mesh:
  out = f(gda)

# `out` can be passed to another pjit call, out.local_shards can be used to
# export the data to non-jax systems (e.g. for checkpointing or logging), etc.
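For example, pulling pjit output back to host memory via local_shards might look like the following sketch (the name host_arrays is illustrative, and out is the GDA produced above):

# Hedged sketch: collect each local shard's data as a numpy array, paired
# with the shard's index into the global array (e.g. as a checkpointing step).
host_arrays = [(shard.index, np.asarray(shard.data)) for shard in out.local_shards]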

classmethod from_batched_callback(global_shape, global_mesh, mesh_axes, data_callback)[source]#

Constructs a GlobalDeviceArray via batched data fetched from data_callback.

Like from_callback, except the callback function is called only once to fetch all data local to this process.

Example

>>> import jax
>>> import numpy as np
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> from jax.experimental.global_device_array import GlobalDeviceArray
...
>>> global_input_shape = (8, 2)
>>> mesh_axes = P('x')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(np.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def batched_cb(indices):
...   assert len(indices) == len(global_mesh.local_devices)
...   return [global_input_data[index] for index in indices]
...
>>> gda = GlobalDeviceArray.from_batched_callback(global_input_shape, global_mesh, mesh_axes, batched_cb)
>>> gda.local_data(0).shape
(2, 2)
Parameters
  • global_shape (Tuple[int, ...]) – The global shape of the array

  • global_mesh (Mesh) – The global mesh representing devices across multiple processes.

  • mesh_axes (PartitionSpec) – See the mesh_axes parameter of GlobalDeviceArray.

  • data_callback (Callable[[Sequence[Tuple[slice, ...]]], Sequence[Union[ndarray, DeviceArray]]]) – Callback that takes as input a batch of indices into the global array value, with length equal to the number of local devices, and returns the corresponding data for each index. The data can be returned as any array-like object, e.g. a numpy.ndarray.

classmethod from_batched_callback_with_devices(global_shape, global_mesh, mesh_axes, data_callback)[source]#

Constructs a GlobalDeviceArray via batched DeviceArrays fetched from data_callback.

Like from_batched_callback, except the callback function is responsible for returning on-device data (e.g. by calling jax.device_put).

Example

>>> import jax
>>> import numpy as np
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> from jax.experimental.global_device_array import GlobalDeviceArray
...
>>> global_input_shape = (8, 2)
>>> mesh_axes = P(('x', 'y'))
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(np.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def cb(cb_inp):
...  dbs = []
...  for inp in cb_inp:
...    index, devices = inp
...    array = global_input_data[index]
...    dbs.extend([jax.device_put(array, device) for device in devices])
...  return dbs
...
>>> gda = GlobalDeviceArray.from_batched_callback_with_devices(
...   global_input_shape, global_mesh, mesh_axes, cb)
>>> gda.local_data(0).shape
(1, 2)
Parameters
  • global_shape (Tuple[int, ...]) – The global shape of the array

  • global_mesh (Mesh) – The global mesh representing devices across multiple processes.

  • mesh_axes (PartitionSpec) – See the mesh_axes parameter of GlobalDeviceArray.

  • data_callback (Callable[[Sequence[Tuple[Tuple[slice, ...], Tuple[Device, ...]]]], Sequence[DeviceArray]]) – Callback that takes as input a batch of (index, devices) pairs, where index is a tuple of slices into the global array value and devices is a tuple of the local devices the corresponding data should be placed on, and returns the data for each entry. The data must be returned as jax DeviceArrays (e.g. produced with jax.device_put).

classmethod from_callback(global_shape, global_mesh, mesh_axes, data_callback)[source]#

Constructs a GlobalDeviceArray via data fetched from data_callback.

data_callback is used to fetch the data for each local slice of the returned GlobalDeviceArray.

Example

>>> import jax
>>> import numpy as np
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> from jax.experimental.global_device_array import GlobalDeviceArray
...
>>> global_input_shape = (8, 8)
>>> mesh_axes = P('x', 'y')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> global_input_data = np.arange(np.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def cb(index):
...  return global_input_data[index]
...
>>> gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)
>>> gda.local_data(0).shape
(4, 2)
Parameters
  • global_shape (Tuple[int, ...]) – The global shape of the array

  • global_mesh (Mesh) – The global mesh representing devices across multiple processes.

  • mesh_axes (PartitionSpec) – See the mesh_axes parameter of GlobalDeviceArray.

  • data_callback (Callable[[Tuple[slice, ...]], Union[ndarray, DeviceArray]]) – Callback that takes indices into the global array value as input and returns the corresponding data of the global array value. The data can be returned as any array-like object, e.g. a numpy.ndarray.

class jax.experimental.global_device_array.Shard(device, index, replica_id, data=None)[source]#

A single data shard of a GlobalDeviceArray.

Parameters
  • device (Device) – Which device this shard resides on.

  • index (Tuple[slice, ...]) – The index into the global array of this shard.

  • replica_id (int) – Integer id indicating which replica of the global array this shard is part of. Always 0 for fully sharded data (i.e. when there’s only 1 replica).

  • data (Optional[DeviceArray]) – The data of this shard. None if device is non-local.
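
As a final illustration, the fields of each Shard can be inspected directly. A minimal sketch, assuming the gda built in the examples above:

# data is None exactly when the shard's device is non-local to this process.
for shard in gda.global_shards:
    if shard.data is not None:
        print(shard.device, shard.index, shard.replica_id, shard.data.shape)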