jax.experimental.mesh_utils.create_hybrid_device_mesh

jax.experimental.mesh_utils.create_hybrid_device_mesh#

jax.experimental.mesh_utils.create_hybrid_device_mesh(mesh_shape, dcn_mesh_shape, devices=None, *, process_is_granule=False, should_sort_granules_by_key=True, allow_split_physical_axes=False)[source]#

Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.

Parameters:
  • mesh_shape (Sequence[int]) – shape of the logical mesh for the faster/inner network, ordered by increasing network intensity, e.g. [replica, data, mdl] where mdl has the most network communication requirements.

  • dcn_mesh_shape (Sequence[int]) – shape of the logical mesh for the slower/outer network, in the same order as mesh_shape.

  • devices (Sequence[Any] | None) – optionally, the devices to construct a mesh for. Defaults to jax.devices().

  • process_is_granule (bool) – if True, this function will treat processes as the units of the slower/outer network. Otherwise it will look for slice_index attributes on devices and use slices as the units. Enabling this is meant as a fallback for platforms that don’t set slice_index.

  • should_sort_granules_by_key (bool) – Whether device granules should be sorted by the granule key, either slice or process index, depending on process_is_granule.

  • allow_split_physical_axes (bool) – If True, we will split physical axes if necessary to produce the desired device mesh.

Raises:

ValueError – if the number of slices to which the devices belong doesn’t equal the product of dcn_mesh_shape, or if the number of devices belonging to any single slice does not equal the product of mesh_shape.

Returns:

A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape that can be fed into jax.sharding.Mesh for hybrid parallelism.

Return type:

ndarray