jax.distributed module

initialize([coordinator_address, ...])

Initializes the JAX distributed system.
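In a multi-process job, every process calls initialize once at startup, before any JAX computation or device access. The call tells it where the coordinator lives (process 0 launches the coordination service at that "host:port"), how many processes take part, and this process's own rank. The sketch below is one way to wire that up. It assumes a launcher that exports the hypothetical environment variables COORDINATOR_ADDRESS, NUM_PROCESSES and PROCESS_ID; those names, and the helper names distributed_kwargs and init_distributed, are illustrative only and are not part of JAX.

```python
import os
from typing import Mapping

import jax


def distributed_kwargs(env: Mapping[str, str]) -> dict:
    """Build keyword arguments for jax.distributed.initialize from a mapping.

    COORDINATOR_ADDRESS, NUM_PROCESSES and PROCESS_ID are hypothetical
    variable names chosen for this sketch; use whatever your launcher sets.
    """
    return {
        # "host:port" where process 0 runs the coordination service.
        "coordinator_address": env["COORDINATOR_ADDRESS"],
        # Total number of processes taking part in the job.
        "num_processes": int(env["NUM_PROCESSES"]),
        # This process's rank, in the range [0, num_processes).
        "process_id": int(env["PROCESS_ID"]),
    }


def init_distributed() -> None:
    # Call once per process, before any JAX computation or device access.
    jax.distributed.initialize(**distributed_kwargs(os.environ))
    # After initialization, device queries reflect the whole job.
    print(
        f"process {jax.process_index()} of {jax.process_count()}: "
        f"{jax.local_device_count()} local / {jax.device_count()} global devices"
    )
```

Each process would call init_distributed() first thing in its entry point. Keeping the argument parsing in a pure function lets it be checked without starting a coordination service. On environments that JAX can auto-detect (for example Cloud TPU or Slurm), initialize can instead be called with no arguments.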

shutdown()

Shuts down the distributed system.
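Because shutdown undoes what initialize set up, it is natural to pair the two so the shutdown runs even when the work in between raises. The sketch below does that with a context manager. The name distributed_session is hypothetical, and the three arguments are simply forwarded to jax.distributed.initialize.

```python
import contextlib
from typing import Iterator

import jax


@contextlib.contextmanager
def distributed_session(
    coordinator_address: str, num_processes: int, process_id: int
) -> Iterator[None]:
    """Initialize the JAX distributed system and always shut it down on exit."""
    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises.
        jax.distributed.shutdown()
```

A process would then run its work inside `with distributed_session("10.0.0.1:1234", 4, rank):`, where the address and process count are placeholders for the real job's values.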