jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None)
Initializes the JAX distributed system.
initialize() prepares JAX for execution on multi-host GPU and Cloud TPU.
initialize() must be called before performing any JAX computations.
The JAX distributed system serves a number of roles:
- it allows JAX processes to discover each other and share topology information,
- it performs health checking, ensuring that all processes shut down if any process dies, and
- it is used for distributed checkpointing.
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they will be chosen automatically.
Otherwise, you must provide the coordinator_address, num_processes, and process_id arguments to initialize().
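Outside those environments the three arguments have to come from somewhere. One common pattern is for a launcher to export them as environment variables. The sketch below assumes hypothetical variable names (COORDINATOR_ADDRESS, NUM_PROCESSES, PROCESS_ID) that JAX itself does not read; it only builds the keyword arguments you would then pass to jax.distributed.initialize().

```python
import os
from typing import Mapping, Optional

# Hypothetical launcher variable names; JAX does not look these up itself.
_REQUIRED = ("COORDINATOR_ADDRESS", "NUM_PROCESSES", "PROCESS_ID")


def explicit_init_kwargs(env: Optional[Mapping[str, str]] = None) -> dict:
    """Build keyword arguments for jax.distributed.initialize().

    Reads the hypothetical COORDINATOR_ADDRESS, NUM_PROCESSES and
    PROCESS_ID variables from `env` (defaults to os.environ).
    """
    env = os.environ if env is None else env
    missing = [name for name in _REQUIRED if name not in env]
    if missing:
        raise KeyError(f"missing launcher variables: {missing}")
    return {
        "coordinator_address": env["COORDINATOR_ADDRESS"],  # e.g. "10.0.0.1:1234"
        "num_processes": int(env["NUM_PROCESSES"]),
        "process_id": int(env["PROCESS_ID"]),
    }


# Usage on each process (needs a real multi-process launch):
#   import jax
#   jax.distributed.initialize(**explicit_init_kwargs())
```

Keeping the parsing separate from the JAX call means every process runs the same script and only the environment differs between hosts.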
Parameters:
- coordinator_address (Optional[str]) – The IP address of process 0 and a port on which that process should launch a coordinator service, given as a host:port string such as '10.0.0.1:1234'. The choice of port does not matter, so long as the port is available on the coordinator and all processes agree on it. May be None only on supported environments, in which case it will be chosen automatically. Note that special addresses like 127.0.0.1 usually mean that the program will bind to a local interface and are not suitable when running in a multi-host environment.
- num_processes (Optional[int]) – Number of processes. May be None only on supported environments, in which case it will be chosen automatically.
- process_id (Optional[int]) – The ID number of the current process. The process_id values across the cluster must be a dense range 0, 1, ..., num_processes - 1. May be None only on supported environments, in which case it will be chosen automatically.
- local_device_ids (Optional[Sequence[int]]) – Restricts the visible devices of the current process to local_device_ids. If None, all local devices are visible to the process, except when processes are launched via Slurm or Open MPI on GPUs; in that case, it defaults to a single device per process.
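The constraints on coordinator_address and process_id can be checked before calling initialize(). The helpers below are a hypothetical sketch, not part of JAX: they split a host:port string, flag loopback hosts such as 127.0.0.1 that other machines cannot reach, and test whether a collection of process IDs forms the required dense range.

```python
import ipaddress


def split_coordinator_address(address: str) -> tuple[str, int]:
    """Split 'host:port' into (host, port). Bracketed IPv6 hosts are not handled."""
    host, sep, port = address.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"expected 'host:port', got {address!r}")
    port_num = int(port)
    if not 0 < port_num <= 65535:
        raise ValueError(f"port out of range: {port_num}")
    return host, port_num


def is_loopback_host(host: str) -> bool:
    """True for hosts that only bind to the local machine."""
    if host == "localhost":
        return True
    try:
        return ipaddress.ip_address(host).is_loopback
    except ValueError:  # a hostname other than localhost
        return False


def is_dense_process_range(process_ids, num_processes: int) -> bool:
    """True if the IDs are exactly 0, 1, ..., num_processes - 1."""
    return sorted(process_ids) == list(range(num_processes))
```

A launcher could, for instance, refuse to start a multi-host job when is_loopback_host() is true for the coordinator host.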
Raises:
RuntimeError – If initialize() is called more than once.
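Because a second call raises RuntimeError, code whose setup path may be reached more than once (for example, a setup function invoked by several entry points) can route the call through a guard. The sketch below is a hypothetical helper, not a JAX API. It only tracks calls made through itself, and it takes the initializer as a parameter so it can be exercised without a cluster.

```python
from typing import Any, Callable

# Module-level flag; it only knows about calls made through initialize_once().
_initialized = False


def initialize_once(init_fn: Callable[..., Any], **kwargs: Any) -> bool:
    """Call init_fn(**kwargs) the first time only.

    Returns True if this call performed the initialization and False if it
    was skipped because an earlier call through this wrapper already did.
    """
    global _initialized
    if _initialized:
        return False
    init_fn(**kwargs)
    _initialized = True  # set only after init_fn returned without raising
    return True


# Usage with the real JAX call:
#   import jax
#   initialize_once(jax.distributed.initialize,
#                   coordinator_address="10.0.0.1:1234",
#                   num_processes=2, process_id=0)
```

Setting the flag only after a successful call means a failed initialization can be retried instead of being silently skipped.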
Example:
Suppose there are two GPU processes, and process 0 is the designated coordinator with address 10.0.0.1:1234. To initialize the GPU cluster, run the following commands on each process before anything else.
On process 0:
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)
On process 1:
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)