jax.sharding package#

Classes#

class jax.sharding.Sharding#

Abstract Sharding interface which describes how a jax.Array is laid out across devices.

property addressable_devices: Set[jaxlib.xla_extension.Device]#

A set of devices that are addressable by the current process.

addressable_devices_indices_map(global_shape)[source]#

A mapping from addressable device to the slice of global data it contains.

addressable_devices_indices_map contains the part of devices_indices_map that applies to the addressable devices.

Parameters

global_shape (Tuple[int, ...]) –

Return type

Mapping[Device, Optional[Tuple[slice, ...]]]
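
Example

A minimal sketch with a single-device sharding (in a single-process run every device is addressable, so this map coincides with devices_indices_map below):

>>> import jax
>>> sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
>>> # The one device holds the entire array, so its entry covers the
>>> # full global shape.
>>> local_map = sharding.addressable_devices_indices_map((8, 4))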

abstract property device_set: Set[jaxlib.xla_extension.Device]#

A set of global devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

Return type

Set[Device]

abstract devices_indices_map(global_shape)[source]#

A global mapping from device to the slice of the global data it contains.

The devices in this mapping are global devices, i.e., the mapping includes non-addressable devices from other processes.

Parameters

global_shape (Tuple[int, ...]) –

Return type

Mapping[Device, Optional[Tuple[slice, ...]]]

property is_fully_addressable: bool#

True if the current process can address all of the devices in device_set.
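
Example

A minimal sketch of reading these properties from a concrete sharding (assuming a single-process run, where every device is addressable):

>>> import jax
>>> sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
>>> devs = sharding.device_set                   # a set containing the one device
>>> fully_local = sharding.is_fully_addressable  # True in a single-process run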

abstract shard_shape(global_shape)[source]#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from the global shape (which it takes as input) and the properties of the sharding.

Parameters

global_shape (Tuple[int, ...]) –

Return type

Tuple[int, ...]
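
Example

For instance, with a NamedSharding over a 2x4 device mesh (a sketch that assumes a runtime with eight devices; Mesh and PartitionSpec are described under NamedSharding below), each dimension of the global shape is divided by the size of the mesh axes it is sharded over:

>>> import jax
>>> import numpy as np
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
>>> sharding.shard_shape((8, 8))   # dim 0 split over 'x' (size 2), dim 1 over 'y' (size 4)
(4, 2)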

class jax.sharding.XLACompatibleSharding#

Bases: jaxlib.xla_extension.Sharding

A Sharding that describes shardings expressible to XLA.

Any Sharding that is a subclass of XLACompatibleSharding will work with all JAX APIs and transformations that use XLA.
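
Example

In newer JAX releases, jax.device_put accepts a Sharding directly as the placement argument; a sketch under that assumption:

>>> import jax
>>> import jax.numpy as jnp
>>> sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
>>> x = jax.device_put(jnp.arange(8), sharding)  # place the data as the sharding describes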

devices_indices_map(global_shape)[source]#

A global mapping from device to the slice of the global data it contains.

The devices in this mapping are global devices, i.e., the mapping includes non-addressable devices from other processes.

Parameters

global_shape (Tuple[int, ...]) –

Return type

Mapping[Device, Tuple[slice, ...]]

shard_shape(global_shape)[source]#

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from the global shape (which it takes as input) and the properties of the sharding.

Parameters

global_shape (Tuple[int, ...]) –

Return type

Tuple[int, ...]

class jax.sharding.NamedSharding#

Bases: jaxlib.xla_extension.XLACompatibleSharding

NamedSharding is a way to express Shardings using named axes.

A Mesh and a PartitionSpec are used together to express a Sharding over named axes.

Mesh is a NumPy array of JAX devices in a multi-dimensional grid, where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for Mesh is "logical mesh".

PartitionSpec is a named tuple whose elements can be None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data is sharded across the x axis of the mesh, and the second dimension is sharded across the y axis of the mesh.

The pjit tutorial (https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html#more-information-on-partitionspec) goes into more detail and includes diagrams that help explain how Mesh and PartitionSpec work together.

Parameters
  • mesh – A jax.experimental.maps.Mesh object.

  • spec – A jax.experimental.PartitionSpec object.

Example

>>> import jax
>>> import numpy as np
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
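
Continuing this example (which assumes a runtime with eight devices), the resulting sharding can be inspected with the methods documented above; a sketch:

>>> # Maps each of the eight devices to the slice of an (8, 8) global
>>> # array that it holds.
>>> index_map = named_sharding.devices_indices_map((8, 8))
>>> local_devices = named_sharding.addressable_devices
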
property device_set: Set[jaxlib.xla_extension.Device]#

A set of global devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

class jax.sharding.SingleDeviceSharding#

Bases: jaxlib.xla_extension.XLACompatibleSharding

A subclass of XLACompatibleSharding that places its data on a single device.

Parameters

device – A single Device.

Example

>>> import jax
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
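
For a single-device sharding the whole array lives on that one device; a minimal sketch of what the methods above report:

>>> devs = single_device_sharding.device_set            # a set with the one device
>>> shape = single_device_sharding.shard_shape((4, 2))  # equals the global shape
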
property device_set: Set[jaxlib.xla_extension.Device]#

A set of global devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.

Return type

Set[Device]

devices_indices_map(global_shape)[source]#

A global mapping from device to the slice of the global data it contains.

The devices in this mapping are global devices, i.e., the mapping includes non-addressable devices from other processes.

Parameters

global_shape (Tuple[int, ...]) –

Return type

Mapping[Device, Tuple[slice, ...]]