jax.experimental.maps.Mesh

class jax.experimental.maps.Mesh(devices, axis_names)

Declare the hardware resources available in the scope of this manager.

In particular, all axis_names become valid resource names inside the managed block and can be used, for example, in the in_axis_resources argument of jax.experimental.pjit.pjit(). See also JAX’s multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the pjit tutorial (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html).
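To make the role of the axis names more concrete, the sketch below works out the per-device block shape when array dimensions are mapped onto named mesh axes. It is plain Python and does not call JAX: shard_shape is a hypothetical helper, a tuple of axis names (or None) stands in for a partition spec, and the requirements that the spec has one entry per dimension and that sizes divide evenly are simplifying assumptions of this sketch.

```python
def shard_shape(global_shape, spec, mesh_shape):
    """Per-device block shape when dimension i is split over mesh axis spec[i].

    Hypothetical helper for illustration only, not a JAX API.
    `spec` holds one mesh axis name per array dimension, or None for
    "do not partition this dimension".
    """
    if len(spec) != len(global_shape):
        raise ValueError("spec needs one entry per array dimension")
    block = []
    for dim, axis in zip(global_shape, spec):
        # An unpartitioned dimension is kept whole on every device.
        parts = 1 if axis is None else mesh_shape[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        block.append(dim // parts)
    return tuple(block)


# A mesh with 4 devices along 'x' and 2 devices along 'y'.
mesh_shape = {"x": 4, "y": 2}

print(shard_shape((8, 2), ("x", "y"), mesh_shape))    # (2, 1)
print(shard_shape((8, 2), ("x", None), mesh_shape))   # (2, 2)
print(shard_shape((8, 2), (None, None), mesh_shape))  # (8, 2)
```

The last call corresponds to leaving every dimension unpartitioned, so each device would hold the full (8, 2) array.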

If you are compiling in multiple threads, make sure that the with Mesh(...) block is entered inside the function that each thread executes, not around the code that starts the threads.
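The following is a minimal sketch of why the placement matters, assuming the active mesh is tracked per thread (which is what the note above implies; the actual bookkeeping inside JAX may differ). FakeMesh, current_axis_names and run_in_thread are hypothetical stand-ins written with the standard library only, not JAX APIs.

```python
import threading

_state = threading.local()  # stand-in for per-thread "current mesh" tracking


class FakeMesh:
    """Hypothetical stand-in for a mesh context manager (not the JAX class)."""

    def __init__(self, axis_names):
        self.axis_names = tuple(axis_names)

    def __enter__(self):
        # Remember what was active in this thread, then install ourselves.
        self._previous = getattr(_state, "mesh", None)
        _state.mesh = self
        return self

    def __exit__(self, exc_type, exc, tb):
        _state.mesh = self._previous
        return False


def current_axis_names():
    """Axis names visible to the calling thread, or () if none are active."""
    mesh = getattr(_state, "mesh", None)
    return mesh.axis_names if mesh is not None else ()


def run_in_thread(fn):
    """Run fn in a new thread and return its result."""
    result = []
    thread = threading.Thread(target=lambda: result.append(fn()))
    thread.start()
    thread.join()
    return result[0]


def worker_without_context():
    return current_axis_names()


def worker_with_context():
    with FakeMesh(("x", "y")):
        return current_axis_names()


with FakeMesh(("x", "y")):
    # The context entered here belongs to the main thread only.
    outside = run_in_thread(worker_without_context)
    inside = run_in_thread(worker_with_context)

print(outside)  # ()
print(inside)   # ('x', 'y')
```

Under that assumption, a worker that relies on a context entered by the thread that spawned it sees no axis names, while a worker that enters the context itself does.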

Parameters
  • devices (ndarray) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).

  • axis_names (Sequence[Any]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
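The relationship between the two parameters can be sketched with NumPy alone. The helper below is hypothetical (it is not how Mesh is implemented) and uses integer placeholders where real code would pass the objects returned by jax.devices(); it pairs each axis name with the size of the matching dimension of the device array and rejects a mismatched count.

```python
import numpy as np


def describe_mesh_axes(device_grid, axis_names):
    """Map each axis name to the size of the matching device-grid dimension.

    Hypothetical helper for illustration only, not the Mesh implementation.
    """
    device_grid = np.asarray(device_grid)
    axis_names = tuple(axis_names)
    if len(axis_names) != device_grid.ndim:
        raise ValueError(
            f"got {len(axis_names)} axis names for a "
            f"{device_grid.ndim}-D device array"
        )
    return dict(zip(axis_names, device_grid.shape))


# Integer placeholders stand in for device objects.
grid = np.arange(8).reshape(4, 2)

print(describe_mesh_axes(grid, ("x", "y")))  # {'x': 4, 'y': 2}
```

Passing a single name such as ("x",) for the same 2-D grid raises ValueError in this sketch, mirroring the requirement that the number of axis names match the rank of devices.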

Example

>>> import jax
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental.pjit import pjit
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=None)(inp)

__init__(devices, axis_names)

Methods

__init__(devices, axis_names)


Attributes

device_ids

empty

is_multi_process

local_devices

local_mesh

shape

size

devices

axis_names