jax.sharding module#
Classes#
- class jax.sharding.Sharding#
Abstract Sharding interface which describes how a jax.Array is laid out across devices.
- property addressable_devices: Set[Device]#
A set of devices that are addressable by the current process.
- addressable_devices_indices_map(global_shape)[source]#
A mapping from addressable device to the slice of global data it contains.
addressable_devices_indices_map contains the part of devices_indices_map that applies to the addressable devices.
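For instance, a minimal sketch of inspecting this map (assuming a single-process run, where all devices are addressable; the devices and slices shown depend on your backend):
>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()), ('x',))
>>> sharding = NamedSharding(mesh, P('x'))
>>> # Maps each local device to the index slice of the global data it holds.
>>> sharding.addressable_devices_indices_map((mesh.size,))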
- property device_set: Set[Device]#
A set of global devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- devices_indices_map(global_shape)[source]#
A global mapping from device to the slice of the global data it contains.
The devices in this mapping are global devices, i.e., they include non-addressable devices from other processes.
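A sketch of the global map (in a single-process run it coincides with the addressable map; slice values depend on the device count n):
>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> n = len(jax.devices())
>>> sharding = NamedSharding(Mesh(np.array(jax.devices()), ('x',)), P('x', None))
>>> # Each of the n devices holds a (2, 4) slice of the (2 * n, 4) array.
>>> sharding.devices_indices_map((2 * n, 4))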
- is_equivalent_to(other, ndim)[source]#
Returns True if two shardings put the same logical array (sharded/unsharded) on the same device(s).
For example, every XLACompatibleSharding lowers to GSPMDSharding, which is a general representation. So jax.sharding.NamedSharding is equivalent to jax.sharding.PositionalSharding if both of them lower to the same GSPMDSharding.
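As an illustrative sketch of that equivalence (assuming jax.sharding.PositionalSharding is available in your JAX version; whether this returns True depends on both shardings lowering to the same GSPMDSharding):
>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PositionalSharding, PartitionSpec as P
>>> devices = np.array(jax.devices())
>>> named = NamedSharding(Mesh(devices, ('x',)), P('x'))
>>> positional = PositionalSharding(devices)
>>> named.is_equivalent_to(positional, ndim=1)  # True if they lower identically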
- class jax.sharding.XLACompatibleSharding#
Bases: Sharding
A Sharding that describes shardings expressible to XLA.
Any Sharding that is a subclass of XLACompatibleSharding will work with all JAX APIs and transformations that use XLA.
- devices_indices_map(global_shape)[source]#
A global mapping from device to the slice of the global data it contains.
The devices in this mapping are global devices, i.e., they include non-addressable devices from other processes.
- is_equivalent_to(other, ndim)[source]#
Returns True if two shardings put the same logical array (sharded/unsharded) on the same device(s).
For example, every XLACompatibleSharding lowers to GSPMDSharding, which is a general representation. So jax.sharding.NamedSharding is equivalent to jax.sharding.PositionalSharding if both of them lower to the same GSPMDSharding.
- Parameters:
self (XLACompatibleSharding) –
other (XLACompatibleSharding) –
ndim (int) –
- Return type:
bool
- class jax.sharding.NamedSharding#
Bases: XLACompatibleSharding
NamedSharding is a way to express Shardings using named axes.
Mesh and PartitionSpec can be used together to express a Sharding with names. Mesh is a NumPy array of JAX devices in a multi-dimensional grid, where each axis of the mesh has a name, e.g. ‘x’ or ‘y’. Another name for Mesh is “logical mesh”. PartitionSpec is a tuple whose elements can be None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data is sharded across the x axis of the mesh, and the second dimension is sharded across the y axis of the mesh.
The Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) goes into more detail and has diagrams that help explain the concepts of Mesh and PartitionSpec.
- Parameters:
mesh – A jax.sharding.Mesh object.
spec – A jax.sharding.PartitionSpec object.
Example
>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
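Continuing the example, the sharding can be passed wherever a Sharding is accepted, e.g. to jax.device_put (a usage sketch; it assumes the eight devices required by the 2x4 mesh above):
>>> arr = jax.device_put(np.arange(32).reshape(8, 4), named_sharding)
>>> arr.sharding  # a NamedSharding equal to the one constructed above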
- class jax.sharding.SingleDeviceSharding#
Bases: XLACompatibleSharding
A subclass of XLACompatibleSharding that places its data on a single device.
- Parameters:
device – A single Device.
Example
>>> import jax
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
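As a usage sketch, data can be committed to that device by passing the sharding to jax.device_put:
>>> import numpy as np
>>> x = jax.device_put(np.arange(4), single_device_sharding)
>>> x.sharding.device_set  # a set containing only jax.devices()[0]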
- property device_set: Set[Device]#
A set of global devices that this Sharding spans.
In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.
- class jax.sharding.PartitionSpec(*partitions)[source]#
Tuple describing how to partition a tensor across a mesh.
Each element is either None, a string, or a tuple of strings. See the NamedSharding class for more details.
We create a separate class for this so JAX’s pytree utilities can distinguish it from a tuple that should be treated as a pytree.
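A few illustrative specs (a sketch; the axis names 'x' and 'y' are assumed to exist in whatever mesh the spec is paired with):
>>> from jax.sharding import PartitionSpec as P
>>> P('x', 'y')          # dim 0 sharded over mesh axis 'x', dim 1 over 'y'
>>> P(('x', 'y'), None)  # dim 0 sharded over both axes, dim 1 replicated
>>> P(None, 'y')         # dim 0 replicated, dim 1 sharded over 'y'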
- class jax.sharding.Mesh(devices, axis_names)[source]#
Declare the hardware resources available in the scope of this manager.
In particular, all axis_names become valid resource names inside the managed block and can be used e.g. in the in_axis_resources argument of jax.experimental.pjit.pjit(). Also see JAX’s multi-process programming model (https://jax.readthedocs.io/en/latest/multi_process.html) and the Distributed arrays and automatic parallelization tutorial (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).
If you are compiling in multiple threads, make sure that the with Mesh context manager is inside the function that the threads will execute.
- Parameters:
devices (Union[ndarray, Sequence[Device]]) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).
axis_names (Union[str, Sequence[Any]]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
Example
>>> import jax
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
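Beyond acting as a context manager, a mesh's layout can be inspected through its attributes (a sketch; given the 4x2 device grid above):
>>> global_mesh.axis_names
('x', 'y')
>>> global_mesh.devices.shape
(4, 2)
>>> global_mesh.shape  # an ordered mapping from axis name to its size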