jax.experimental.mesh_utils.create_device_mesh

jax.experimental.mesh_utils.create_device_mesh(mesh_shape, devices=None, *, contiguous_submeshes=False, allow_split_physical_axes=False)

Creates a performant device mesh for jax.sharding.Mesh.

Parameters:
  • mesh_shape (Sequence[int]) – shape of the logical mesh, ordered by increasing network intensity, e.g. [replica, data, mdl], where mdl has the highest network communication requirements.

  • devices (Sequence[Any] | None) – optionally, the devices to construct a mesh for. Defaults to jax.devices().

  • contiguous_submeshes (bool) – if True, this function will attempt to create a mesh where each process’s local devices form a contiguous submesh. A ValueError will be raised if this function can’t produce a suitable mesh. This setting was sometimes necessary before the introduction of jax.Array to ensure non-ragged local arrays; if using jax.Arrays, it’s better to keep this set to False.

  • allow_split_physical_axes (bool) – if True, physical axes are split when necessary to produce the desired device mesh.
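
As an illustration of the parameters above, here is a minimal sketch of a call that passes devices explicitly and leaves both keyword flags at their defaults. The two-axis shape (1, n) is an assumption chosen so the example runs on any device count; on a real accelerator topology a different factorization would normally be used.

```python
import jax
from jax.experimental import mesh_utils

# All devices visible to JAX; this is also what devices=None defaults to.
devices = jax.devices()
n = len(devices)

# Illustrative logical mesh: axes are ordered by increasing network
# intensity, so the last axis is the most communication-heavy one.
mesh_shape = (1, n)

device_mesh = mesh_utils.create_device_mesh(mesh_shape, devices=devices)
print(device_mesh.shape)
```

The result is a NumPy array of device objects whose shape matches mesh_shape.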

Raises:

ValueError – if the number of devices doesn’t equal the product of mesh_shape.
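
The mismatch case can be reproduced with a short sketch. The shape (n + 1,) is a deliberately invalid value picked for illustration, since its product can never equal the number of available devices.

```python
import jax
from jax.experimental import mesh_utils

n = len(jax.devices())

# Deliberately request one more device than is available.
bad_shape = (n + 1,)

try:
    mesh_utils.create_device_mesh(bad_shape)
    raised = False
except ValueError as err:
    raised = True
    print(f"ValueError: {err}")
```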

Returns:

A np.ndarray of JAX devices with mesh_shape as its shape that can be fed into jax.sharding.Mesh with good collective performance.

Return type:

ndarray
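
A minimal sketch of feeding the returned array into jax.sharding.Mesh and using it to shard an array follows. The axis names "data" and "model", the array size, and the (1, n) mesh shape are assumptions made for illustration, and the sketch assumes a single-process setup where all devices are addressable.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()

# Device array with shape (1, n); the last axis carries the most traffic.
device_mesh = mesh_utils.create_device_mesh((1, n))

# Name the logical axes (names are illustrative).
mesh = Mesh(device_mesh, axis_names=("data", "model"))

# Shard the second dimension of x across the "model" axis.
x = jnp.zeros((4, 8 * n))
sharding = NamedSharding(mesh, P("data", "model"))
x_sharded = jax.device_put(x, sharding)

print(mesh.axis_names, x_sharded.shape)
```

The second dimension of x is sized as a multiple of n so that it divides evenly across the "model" axis.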