jax.sharding module#

Classes#

class jax.sharding.Sharding#

Abstract Sharding interface which describes how a jax.Array is laid out across devices.

property addressable_devices: Set[Device]#

A set of devices that are addressable by the current process.

addressable_devices_indices_map(global_shape)[source]#

A mapping from addressable device to the slice of global data it contains.

addressable_devices_indices_map contains the part of devices_indices_map that applies to the addressable devices.

Parameters:

global_shape (Tuple[int, ...]) –

Return type:

Mapping[Device, Optional[Tuple[slice, ...]]]
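For illustration, a minimal sketch (not part of the official docstring) using the concrete SingleDeviceSharding subclass, so it runs in a single-process, single-device setting:

>>> import jax
>>> s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
>>> # The one local device is addressable and holds the whole array, so each
>>> # dimension of a rank-2 global shape maps to the full slice.
>>> idx_map = s.addressable_devices_indices_map((4, 2))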

property device_set: Set[Device]#

A set of global devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
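A small illustrative sketch (assuming a single-process run, so every device in the set is local):

>>> import jax
>>> s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
>>> num_spanned = len(s.device_set)  # 1: the sharding spans only that device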

devices_indices_map(global_shape)[source]#

A global mapping from device to the slice of the global data it contains.

The devices in this mapping are global devices, i.e., they include non-addressable devices from other processes.

Parameters:

global_shape (Tuple[int, ...]) –

Return type:

Mapping[Device, Optional[Tuple[slice, ...]]]
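An illustrative sketch (not from the docstring), assuming 8 devices are available so the 2x4 mesh below can be built:

>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> s = NamedSharding(mesh, P('x', 'y'))
>>> # For an (8, 8) global array each device maps to one (4, 2) block,
>>> # e.g. (slice(0, 4), slice(0, 2)) for the first device in the mesh.
>>> idx_map = s.devices_indices_map((8, 8))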

is_equivalent_to(other, ndim)[source]#

Returns True if two shardings put the same logical array (sharded/unsharded) on the same device(s).

For example, every XLACompatibleSharding lowers to GSPMDSharding which is a general representation. So jax.sharding.NamedSharding is equivalent to jax.sharding.PositionalSharding if both of them lower to the same GSPMDSharding.

Parameters:
  • other (Sharding) –

  • ndim (int) –

Return type:

bool
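A hedged sketch (assuming 8 devices for a 2x4 mesh): two shardings built from the same mesh and spec are expected to be equivalent for a rank-2 array:

>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> s1 = NamedSharding(mesh, P('x', 'y'))
>>> s2 = NamedSharding(mesh, P('x', 'y'))
>>> # Same mesh, same spec -> same placement of a 2D logical array.
>>> equivalent = s1.is_equivalent_to(s2, 2)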

property is_fully_addressable: bool#

True if the current process can address all of the devices in device_set.
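For illustration, a sketch assuming a single-process run, where all local devices are addressable:

>>> import jax
>>> s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
>>> all_local = s.is_fully_addressable  # True when this process owns the device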

shard_shape(global_shape)[source]#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from the global shape (which this function takes as an input) and the properties of the sharding.

Parameters:

global_shape (Tuple[int, ...]) –

Return type:

Tuple[int, ...]
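A minimal sketch (assuming 8 devices for a 2x4 mesh) showing how the per-device shape follows from the global shape and the spec:

>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> s = NamedSharding(mesh, P('x', 'y'))
>>> # (8, 8) split 2 ways on dim 0 and 4 ways on dim 1 -> (4, 2) per device.
>>> per_device_shape = s.shard_shape((8, 8))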

class jax.sharding.XLACompatibleSharding#

Bases: Sharding

A Sharding that describes shardings expressible to XLA.

Any Sharding that is a subclass of XLACompatibleSharding will work with all JAX APIs and transformations that use XLA.

devices_indices_map(global_shape)[source]#

A global mapping from device to the slice of the global data it contains.

The devices in this mapping are global devices, i.e., they include non-addressable devices from other processes.

Parameters:

global_shape (Tuple[int, ...]) –

Return type:

Mapping[Device, Tuple[slice, ...]]

is_equivalent_to(other, ndim)[source]#

Returns True if two shardings put the same logical array (sharded/unsharded) on the same device(s).

For example, every XLACompatibleSharding lowers to GSPMDSharding which is a general representation. So jax.sharding.NamedSharding is equivalent to jax.sharding.PositionalSharding if both of them lower to the same GSPMDSharding.

Parameters:
  • other (XLACompatibleSharding) –

  • ndim (int) –

Return type:

bool

shard_shape(global_shape)[source]#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from the global shape (which this function takes as an input) and the properties of the sharding.

Parameters:

global_shape (Tuple[int, ...]) –

Return type:

Tuple[int, ...]

class jax.sharding.NamedSharding#

Bases: XLACompatibleSharding

NamedSharding is a way to express Shardings using named axes.

Mesh and PartitionSpec can be used to express a Sharding with a name.

Mesh is a NumPy array of JAX devices in a multi-dimensional grid, where each axis of the mesh has a name, e.g. ‘x’ or ‘y’. Another name for Mesh is “logical mesh”.

PartitionSpec is a tuple whose elements can be None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of the data is sharded across the x axis of the mesh, and the second dimension is sharded across the y axis of the mesh.

The Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) goes into more detail and has diagrams explaining the concepts of Mesh and PartitionSpec.

Parameters:
  • mesh – A jax.sharding.Mesh object.

  • spec – A jax.sharding.PartitionSpec object.

Example

>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
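As a follow-up sketch (reusing mesh and named_sharding from the example above, which assumes 8 devices), the spec P('x', 'y') splits the first data dimension over the 2-way x axis and the second over the 4-way y axis:

>>> # Each device would hold a (4, 1) block of an (8, 4) global array.
>>> block_shape = named_sharding.shard_shape((8, 4))
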
property device_set: Set[Device]#

A set of global devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

class jax.sharding.SingleDeviceSharding#

Bases: XLACompatibleSharding

A subclass of XLACompatibleSharding that places its data on a single device.

Parameters:

device – A single Device.

Example

>>> import jax
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
property device_set: Set[Device]#

A set of global devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

devices_indices_map(global_shape)[source]#

A global mapping from device to the slice of the global data it contains.

The devices in this mapping are global devices, i.e., they include non-addressable devices from other processes.

Parameters:

global_shape (Tuple[int, ...]) –

Return type:

Mapping[Device, Tuple[slice, ...]]
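For illustration, a sketch reusing single_device_sharding from the example above: the sole device owns the entire array, so each dimension maps to the full slice:

>>> idx_map = single_device_sharding.devices_indices_map((4, 2))
>>> # idx_map has a single entry: the device paired with
>>> # (slice(None, None, None), slice(None, None, None)).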

class jax.sharding.PartitionSpec(*partitions)[source]#

Tuple describing how to partition a tensor across a mesh.

Each element is either None, a string, or a tuple of strings. See the NamedSharding class for more details.

We create a separate class for this so JAX’s pytree utilities can distinguish it from a tuple that should be treated as a pytree.
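A few illustrative specs (a sketch; the axis names 'x' and 'y' are assumed to exist in the enclosing mesh):

>>> from jax.sharding import PartitionSpec as P
>>> p1 = P('x', 'y')    # dim 0 over mesh axis 'x', dim 1 over mesh axis 'y'
>>> p2 = P('x', None)   # dim 0 over 'x', dim 1 replicated
>>> p3 = P(('x', 'y'))  # dim 0 over both 'x' and 'y' jointly
>>> p4 = P()            # fully replicated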

class jax.sharding.Mesh(devices, axis_names)[source]#

Declare the hardware resources available in the scope of this manager.

In particular, all axis_names become valid resource names inside the managed block and can be used e.g. in the in_axis_resources argument of jax.experimental.pjit.pjit(). Also see JAX’s multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).

If you are compiling in multiple threads, make sure that the with Mesh context manager is inside the function that the threads will execute.

Parameters:
  • devices (Union[ndarray, Sequence[Device]]) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).

  • axis_names (Union[str, Sequence[Any]]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.

Example

>>> import jax
>>> import numpy as np
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
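
As a brief follow-up sketch (reusing global_mesh from above, and assuming the standard Mesh attributes), the mesh can be introspected directly:

>>> axis_names = global_mesh.axis_names  # ('x', 'y')
>>> axis_sizes = global_mesh.shape       # mapping of axis name to size, e.g. {'x': 4, 'y': 2}
>>> total_devices = global_mesh.size     # 8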