jax.experimental.mesh_utils module



Utilities for building a device mesh.

API

create_device_mesh(mesh_shape[, devices, ...])

Creates a performant device mesh for jax.sharding.Mesh.
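A minimal usage sketch, assuming JAX is installed and any number of visible devices (a single CPU device is enough to run it). create_device_mesh returns a NumPy array of device objects with shape mesh_shape, which is then handed to jax.sharding.Mesh. The axis name "data" and the array shapes are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_devices = jax.device_count()

# 1-D mesh over every visible device. On TPU the helper may reorder devices so
# that mesh axes follow the physical interconnect; on other platforms the
# device list is typically just reshaped in order.
device_array = mesh_utils.create_device_mesh((n_devices,))

# "data" is an illustrative axis name chosen for this sketch.
mesh = Mesh(device_array, axis_names=("data",))

# Shard the leading dimension of a (2 * n_devices, 4) array across the "data"
# axis and replicate the second dimension.
x = jnp.arange(8 * n_devices, dtype=jnp.float32).reshape(2 * n_devices, 4)
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))
```

Because the mesh is an ordinary array of devices, the same object can back any number of NamedSharding specs; only the PartitionSpec changes per array.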

create_hybrid_device_mesh(mesh_shape, ...[, ...])

Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
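A sketch of the hybrid variant, assuming one granule per process (process_is_granule=True) so that it also runs in a single-process CPU setup; on multi-slice TPU deployments the default groups devices into granules by slice instead. mesh_shape describes the layout inside each granule and dcn_mesh_shape describes how granules are arranged, so the returned array's shape is their elementwise product. The axis names "dcn" and "ici" are illustrative.

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Per-granule (here: per-process) layout: all local devices along the second axis.
ici_shape = (1, jax.local_device_count())
# Granule layout: one granule per process along the first axis.
dcn_shape = (jax.process_count(), 1)

hybrid = mesh_utils.create_hybrid_device_mesh(
    ici_shape,
    dcn_shape,
    process_is_granule=True,  # treat each process as a granule instead of each TPU slice
)

# Resulting shape is the elementwise product: (process_count, local_device_count).
mesh = Mesh(hybrid, axis_names=("dcn", "ici"))
```

Placing the slower cross-granule links on their own mesh axis lets you reserve that axis for communication-light parallelism (for example, data parallelism) while heavier collectives stay within a granule.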