jax.distributed.initialize
- jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None)
Initializes the JAX distributed system.
Calling initialize() prepares JAX for execution on multi-host GPU and Cloud TPU. initialize() must be called before performing any JAX computations.
The JAX distributed system serves a number of roles:
- it allows JAX processes to discover each other and share topology information,
- it performs health checking, ensuring that all processes shut down if any process dies, and
- it is used for distributed checkpointing.
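To make the topology-sharing role concrete, the sketch below reports what a single process can see. In a two-process job, each process would report a different process_index but the same process_count and global device count. describe_topology is a hypothetical helper name; jax.process_index, jax.process_count, jax.local_device_count and jax.device_count are standard JAX functions.

```python
import jax


def describe_topology() -> dict:
    """Summarize what this process can see of the cluster (hypothetical helper).

    After jax.distributed.initialize() has run on every process, these values
    describe the whole cluster. In a plain single-process run, process_count()
    is 1 and process_index() is 0.
    """
    return {
        # Index of this process within the cluster, in [0, process_count).
        "process_index": jax.process_index(),
        # Total number of JAX processes that joined the cluster.
        "process_count": jax.process_count(),
        # Devices attached to this process only.
        "local_device_count": jax.local_device_count(),
        # Devices across all processes.
        "global_device_count": jax.device_count(),
    }
```

In a real job, call jax.distributed.initialize(...) first and only then describe_topology(), because querying devices initializes the JAX backend and initialize() must come before that.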
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they will be chosen automatically.
Otherwise, you must provide the coordinator_address, num_processes, and process_id arguments to initialize().
- Parameters
coordinator_address (Optional[str]) – The IP address of process 0 and a port on which that process should launch a coordinator service, given as host:port (for example, 10.0.0.1:1234). The choice of port does not matter, so long as the port is available on the coordinator and all processes agree on it. May be None only on supported environments (such as TPU, Slurm, or Open MPI), in which case it will be chosen automatically. Note that special addresses like localhost or 127.0.0.1 usually mean that the program will bind to a local interface and are not suitable when running in a multi-host environment.
num_processes (Optional[int]) – Number of processes. May be None only on supported environments, in which case it will be chosen automatically.
process_id (Optional[int]) – The ID number of the current process. The process_id values across the cluster must form the dense range 0, 1, …, num_processes - 1. May be None only on supported environments; if None, it will be chosen automatically.
local_device_ids (Union[None, int, Sequence[int]]) – Restricts the visible devices of the current process to local_device_ids. If None, all local devices are visible to the process, except when processes are launched via Slurm or Open MPI on GPUs, in which case it defaults to a single device per process.
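Outside the auto-detected environments, the three required arguments have to come from somewhere, usually a launcher. The sketch below assumes a custom launcher that exports the hypothetical variables MY_COORDINATOR, MY_NUM_PROCS, MY_PROC_ID and, optionally, MY_LOCAL_RANK (none of these names are defined by JAX). It checks the host:port shape and the dense process_id range described above, and falls back to an empty dict so that initialize() can auto-detect everything when no variables are set.

```python
import os
from typing import Mapping, Optional


def init_kwargs_from_env(env: Optional[Mapping[str, str]] = None) -> dict:
    """Build keyword arguments for jax.distributed.initialize() (hypothetical helper).

    Reads the hypothetical launcher variables MY_COORDINATOR, MY_NUM_PROCS,
    MY_PROC_ID and, optionally, MY_LOCAL_RANK.
    """
    env = os.environ if env is None else env
    keys = ("MY_COORDINATOR", "MY_NUM_PROCS", "MY_PROC_ID")
    present = [k for k in keys if k in env]
    if not present:
        return {}  # Let initialize() choose everything automatically.
    if len(present) != len(keys):
        missing = sorted(set(keys) - set(present))
        raise ValueError(f"missing launcher variables: {missing}")

    # coordinator_address must look like host:port.
    address = env["MY_COORDINATOR"]
    host, sep, port = address.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"expected 'host:port', got {address!r}")

    num_processes = int(env["MY_NUM_PROCS"])
    process_id = int(env["MY_PROC_ID"])
    # process_id values must form the dense range 0 .. num_processes - 1.
    if not 0 <= process_id < num_processes:
        raise ValueError(f"process_id {process_id} not in [0, {num_processes})")

    kwargs = {
        "coordinator_address": address,
        "num_processes": num_processes,
        "process_id": process_id,
    }
    # Optionally restrict this process to a single local device, for example
    # when several processes share one multi-GPU host.
    if "MY_LOCAL_RANK" in env:
        kwargs["local_device_ids"] = [int(env["MY_LOCAL_RANK"])]
    return kwargs
```

Each process would then call jax.distributed.initialize(**init_kwargs_from_env()) before any other JAX work.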
- Raises
RuntimeError – If initialize() is called more than once.
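Because a second call raises RuntimeError, code that may run its setup path more than once (for example, a library entry point) can wrap initialize() in a process-level guard. The sketch below assumes every call goes through this wrapper; it cannot see direct calls to jax.distributed.initialize made elsewhere. initialize_once and its init_fn parameter are hypothetical names, not part of the JAX API; init_fn exists only so the guard can be exercised without a real cluster.

```python
from typing import Any, Callable, Optional

import jax

# Module-level flag recording whether initialize_once has already succeeded.
_initialized = False


def initialize_once(init_fn: Optional[Callable[..., Any]] = None, **kwargs: Any) -> bool:
    """Call jax.distributed.initialize(**kwargs) at most once per process.

    Returns True if this call ran the initializer and False if it was
    skipped because an earlier call through this wrapper already did.
    """
    global _initialized
    if _initialized:
        return False
    (init_fn or jax.distributed.initialize)(**kwargs)
    # Set the flag only after the initializer returns without raising.
    _initialized = True
    return True
```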
Example:
Suppose there are two GPU processes, and process 0 is the designated coordinator with address 10.0.0.1:1234. To initialize the GPU cluster, run the following commands before anything else.
On process 0:
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)
On process 1:
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)
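Once both processes have initialized, a cheap way to confirm that they see each other is a collective sum over every device in the cluster. In the sketch below, each process contributes a 1 per local device, and lax.psum over the pmapped axis adds contributions from all devices on all processes, so each returned value should equal jax.device_count(), the total across both processes. check_all_devices is a hypothetical helper name; jax.pmap, jax.lax.psum and jax.device_get are standard JAX functions.

```python
import jax
import jax.numpy as jnp


def check_all_devices() -> list:
    """Sum a 1 contributed by every device in the cluster (hypothetical helper).

    Returns one value per local device; after a successful multi-process
    initialize(), every value should equal jax.device_count().
    """
    # One element per device attached to this process.
    xs = jnp.ones(jax.local_device_count())
    # psum over axis "i" reduces across all devices participating in the pmap,
    # which spans every process once the distributed system is initialized.
    summed = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
    # Copy the per-device results to host memory as plain Python floats.
    return [float(v) for v in jax.device_get(summed)]
```

Running check_all_devices() on each process right after the initialize() call above should return a list whose entries all equal the total number of GPUs across both processes; a smaller value suggests a process did not join the cluster.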