jax.distributed.initialize
- jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None, cluster_detection_method=None, initialization_timeout=300, coordinator_bind_address=None)
Initializes the JAX distributed system.
Calling initialize() prepares JAX for execution on multi-host GPU and Cloud TPU. initialize() must be called before performing any JAX computations.

The JAX distributed system serves a number of roles:
- It allows JAX processes to discover each other and share topology information,
- It performs health checking, ensuring that all processes shut down if any process dies, and
- It is used for distributed checkpointing.
If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they will be chosen automatically.
The cluster_detection_method argument may be used to choose a specific method for detecting those distributed arguments. You may pass any of the automatic spec_detect_methods to this argument, though it is not necessary in the TPU, Slurm, or Open MPI cases. For other MPI installations, if you have a functional mpi4py installed, you may pass cluster_detection_method="mpi4py" to bootstrap the required arguments.

Otherwise, you must provide the coordinator_address, num_processes, and process_id arguments to initialize().

Please note: on some systems, particularly HPC clusters that only access external networks through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to initialize() may time out. You may need to unset these variables prior to application launch.
- Parameters:
coordinator_address (str | None) – The IP address of process 0 and a port on which that process should launch a coordinator service. The choice of port does not matter, so long as the port is available on the coordinator and all processes agree on the port. May be None only on supported environments, in which case it will be chosen automatically. Note that special addresses like localhost or 127.0.0.1 usually mean that the program will bind to a local interface and are not suitable when running in a multi-host environment.
num_processes (int | None) – Number of processes. May be None only on supported environments, in which case it will be chosen automatically.
process_id (int | None) – The ID number of the current process. The process_id values across the cluster must be a dense range 0, 1, …, num_processes - 1. May be None only on supported environments; if None, it will be chosen automatically.
local_device_ids (int | Sequence[int] | None) – Restricts the visible devices of the current process to local_device_ids. If None, defaults to all local devices being visible to the process, except when processes are launched via Slurm or Open MPI on GPUs; in that case, it defaults to a single device per process.
cluster_detection_method (str | None) – An optional string to attempt to autodetect the configuration of the distributed run. Note that the “mpi4py” method requires you to have a working mpi4py install in your environment and to launch the application with an MPI-compatible job launcher such as mpiexec or mpirun. Legacy auto-detect options (OMPI, Slurm) remain enabled.
initialization_timeout (int) – Time (in seconds) during which the connection will be retried. If initialization takes longer than this timeout, it fails with an error. Defaults to 300 seconds (5 minutes).
coordinator_bind_address (str | None) – The address and port to which the coordinator service on process 0 should bind. If this is not specified, the default is to bind to all available addresses on the same port as coordinator_address. On systems that have multiple network interfaces per node, it may be insufficient to have the coordinator service listen on only one address/interface.
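The argument constraints above can be checked before calling initialize(). The helper below is a minimal, hypothetical sketch (check_distributed_args and _is_loopback are not part of JAX, which performs its own validation). It assumes explicit arguments, and it treats num_processes > 1 as meaning a multi-host run, which is an assumption: several processes on one host can legitimately use a loopback address. It can only confirm that this process's process_id lies in 0..num_processes - 1; whether IDs across the cluster are dense and unique can only be checked globally.

```python
import ipaddress
from typing import List


def _is_loopback(host: str) -> bool:
    """True for localhost and loopback IPs such as 127.0.0.1 or [::1]."""
    if host.lower() == "localhost":
        return True
    try:
        return ipaddress.ip_address(host.strip("[]")).is_loopback
    except ValueError:
        # Some other hostname; resolving it would need DNS, so do not flag it.
        return False


def check_distributed_args(
    coordinator_address: str,
    num_processes: int,
    process_id: int,
    initialization_timeout: int = 300,
) -> List[str]:
    """Hypothetical pre-flight check; returns human-readable issues (empty if none)."""
    issues: List[str] = []

    # Expect "<host>:<port>"; rpartition keeps bracketed IPv6 hosts like "[::1]" intact.
    host, sep, port = coordinator_address.rpartition(":")
    if not (sep and host and port.isdigit() and 0 < int(port) < 65536):
        issues.append(f"{coordinator_address!r} is not of the form host:port")
    elif num_processes > 1 and _is_loopback(host):
        # Assumption: more than one process is treated as a multi-host run.
        issues.append(f"{host!r} binds to a local interface; remote hosts cannot reach it")

    # Locally, only the range of this process's ID can be checked.
    if num_processes < 1:
        issues.append("num_processes must be at least 1")
    elif not 0 <= process_id < num_processes:
        issues.append(f"process_id {process_id} is outside 0..{num_processes - 1}")

    if initialization_timeout <= 0:
        issues.append("initialization_timeout must be a positive number of seconds")
    return issues
```

For the two-process example, check_distributed_args("10.0.0.1:1234", 2, 0) reports no issues, while check_distributed_args("localhost:1234", 2, 0) flags the loopback address.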
- Raises:
RuntimeError – If initialize() is called more than once or if called after the backend is already initialized.
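Because a second call raises RuntimeError, a program that may reach initialization from more than one entry point (for example, a library helper and a main script) can route the call through a guard. The sketch below is a hypothetical helper, not a JAX API: it records success in a module-level flag protected by a lock, and it accepts an injectable init_fn (defaulting to jax.distributed.initialize) so the logic can be exercised without a cluster. It only guards calls made through it, and it cannot help with the other failure mode; initialization must still happen before any JAX computation initializes the backend.

```python
import threading
from typing import Any, Callable, Optional

_lock = threading.Lock()
_initialized = False


def initialize_once(init_fn: Optional[Callable[..., Any]] = None, **kwargs: Any) -> bool:
    """Run the distributed initializer at most once per process (hypothetical helper).

    Returns True if this call performed initialization, False if it was skipped.
    """
    global _initialized
    with _lock:
        if _initialized:
            return False
        if init_fn is None:
            import jax  # deferred so the guard can be imported without JAX

            init_fn = jax.distributed.initialize
        init_fn(**kwargs)
        # The flag is set only after init_fn returns without raising.
        _initialized = True
        return True
```

In the two-process example, process 0 would call initialize_once(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0); any later call through the guard returns False instead of raising.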
Examples:
Suppose there are two GPU processes, and process 0 is the designated coordinator with address 10.0.0.1:1234. To initialize the GPU cluster, run the following commands before anything else.

On process 0:
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)
On process 1:
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)
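The two calls above differ only in process_id, so a single script can serve every process if it derives its arguments from the environment. The sketch below assumes hypothetical variables MYAPP_COORDINATOR_ADDRESS, MYAPP_NUM_PROCESSES, MYAPP_PROCESS_ID and MYAPP_USE_MPI4PY that your own launcher sets; JAX does not read them. It falls back to a bare initialize() call, relying on automatic detection, as on TPU, Slurm, or Open MPI, and uses cluster_detection_method="mpi4py" when requested.

```python
import os
from typing import Any, Dict, Mapping


def initialize_kwargs(env: Mapping[str, str]) -> Dict[str, Any]:
    """Pick jax.distributed.initialize() arguments from hypothetical MYAPP_* variables."""
    if env.get("MYAPP_USE_MPI4PY") == "1":
        # Other MPI installations with a working mpi4py (takes precedence here).
        return {"cluster_detection_method": "mpi4py"}
    if "MYAPP_COORDINATOR_ADDRESS" in env:
        # Explicit configuration, as in the two-process example above.
        return {
            "coordinator_address": env["MYAPP_COORDINATOR_ADDRESS"],
            "num_processes": int(env["MYAPP_NUM_PROCESSES"]),
            "process_id": int(env["MYAPP_PROCESS_ID"]),
        }
    # TPU, Slurm, and Open MPI: all arguments can be detected automatically.
    return {}


def main() -> None:
    import jax  # deferred so initialize_kwargs() can be used without JAX

    # Initialize before performing any other JAX computation.
    jax.distributed.initialize(**initialize_kwargs(os.environ))
    print(
        f"process {jax.process_index()} of {jax.process_count()}: "
        f"{jax.local_device_count()} local of {jax.device_count()} devices"
    )
```

For the example above, process 1 would be started with MYAPP_COORDINATOR_ADDRESS=10.0.0.1:1234, MYAPP_NUM_PROCESSES=2 and MYAPP_PROCESS_ID=1 in its environment before calling main().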