jax.distributed.initialize

jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None, cluster_detection_method=None, initialization_timeout=300, coordinator_bind_address=None)

Initializes the JAX distributed system.

Calling initialize() prepares JAX for execution on multi-host GPU and Cloud TPU. initialize() must be called before performing any JAX computations.

The JAX distributed system serves a number of roles:

  • It allows JAX processes to discover each other and share topology information,

  • It performs health checking, ensuring that all processes shut down if any process dies, and

  • It is used for distributed checkpointing.

If you are using TPU, Slurm, or Open MPI, all arguments are optional: if omitted, they will be chosen automatically.

The cluster_detection_method argument may be used to choose a specific method for detecting those distributed arguments. You may pass any of the automatic spec_detect_methods to this argument, though doing so is not necessary in the TPU, Slurm, or Open MPI cases. For other MPI installations, if you have a working mpi4py installation, you may pass cluster_detection_method="mpi4py" to bootstrap the required arguments.
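To illustrate when the "mpi4py" option applies, here is a minimal sketch that selects it only if mpi4py is importable and the process appears to have been started by an MPI launcher. The helper choose_detection_method is hypothetical, and checking for PMI_RANK and PMI_SIZE assumes an MPICH/Hydra-style launcher that exports those variables; other MPI implementations may use different names.

```python
import importlib.util
import os
from typing import Mapping, Optional

# Assumption: MPICH/Hydra-style launchers export these variables to each rank.
# Other MPI implementations may use different names; adjust as needed.
PMI_VARS = ("PMI_RANK", "PMI_SIZE")


def choose_detection_method(
    env: Optional[Mapping[str, str]] = None,
    have_mpi4py: Optional[bool] = None,
) -> Optional[str]:
    """Hypothetical helper: return "mpi4py" when it looks usable, else None.

    Returning None leaves cluster_detection_method at its default, so JAX
    keeps using its built-in detection (TPU, Slurm, Open MPI).
    """
    if env is None:
        env = os.environ
    if have_mpi4py is None:
        # find_spec only checks that the package is importable; it does not
        # verify that the underlying MPI library actually works.
        have_mpi4py = importlib.util.find_spec("mpi4py") is not None
    launched_by_mpi = all(name in env for name in PMI_VARS)
    return "mpi4py" if have_mpi4py and launched_by_mpi else None
```

A process could then call jax.distributed.initialize(cluster_detection_method=choose_detection_method()); when the helper returns None, the call behaves as if the argument had been omitted.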

If none of these detection methods applies, you must provide the coordinator_address, num_processes, and process_id arguments to initialize().
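These three arguments usually come from whatever launcher starts your processes. As a minimal sketch, assuming a hypothetical launcher that exports MY_LAUNCHER_COORDINATOR, MY_LAUNCHER_WORLD_SIZE, and MY_LAUNCHER_RANK (placeholder names, not variables JAX reads), they can be collected like this:

```python
import os
from typing import Any, Dict, Mapping, Optional

# Hypothetical variable names exported by a custom launcher; rename them to
# whatever your scheduler actually provides.
COORDINATOR_VAR = "MY_LAUNCHER_COORDINATOR"  # e.g. "10.0.0.1:1234"
WORLD_SIZE_VAR = "MY_LAUNCHER_WORLD_SIZE"    # total number of processes
RANK_VAR = "MY_LAUNCHER_RANK"                # this process's index


def initialize_kwargs_from_env(
    env: Optional[Mapping[str, str]] = None,
) -> Dict[str, Any]:
    """Build the three required initialize() arguments from the environment."""
    if env is None:
        env = os.environ
    try:
        return {
            "coordinator_address": env[COORDINATOR_VAR],
            "num_processes": int(env[WORLD_SIZE_VAR]),
            "process_id": int(env[RANK_VAR]),
        }
    except KeyError as missing:
        # Fail early with a clear message instead of a later connection error.
        raise RuntimeError(f"launcher did not set {missing}") from None
```

Each process would then call jax.distributed.initialize(**initialize_kwargs_from_env()) before doing any other JAX work.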

Please note: on some systems, particularly HPC clusters that only access external networks through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to initialize() may time out. You may need to unset these variables prior to application launch.
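The simplest fix is in the job script, for example running unset HTTP_PROXY HTTPS_PROXY http_proxy https_proxy before starting Python. If that is not possible, the following sketch removes the variables from within Python before initialize() is called. It assumes the proxy settings are read from the environment when the connection is made rather than at interpreter start-up, which you should confirm on your system; the list of variable names is also an assumption, so extend it to match your site.

```python
import os

# Assumption: these are the proxy variables most likely to interfere; many
# tools also honor the lowercase forms. Extend the tuple for your site.
PROXY_VARS = ("HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy")


def unset_proxy_vars(names=PROXY_VARS):
    """Remove proxy variables from this process's environment.

    Returns the removed name/value pairs so they can be restored later.
    """
    removed = {}
    for name in names:
        value = os.environ.pop(name, None)
        if value is not None:
            removed[name] = value
    return removed
```

Call unset_proxy_vars() at the top of the program, before jax.distributed.initialize(); if other parts of the application need the proxy, the returned dictionary can be written back into os.environ after initialization.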

Parameters:
  • coordinator_address (str | None) – The IP address of process 0 and a port on which that process should launch a coordinator service. The choice of port does not matter, so long as the port is available on the coordinator and all processes agree on the port. May be None only on supported environments, in which case it will be chosen automatically. Note that special addresses like localhost or 127.0.0.1 usually mean that the program will bind to a local interface and are not suitable when running in a multi-host environment.

  • num_processes (int | None) – Number of processes. May be None only on supported environments, in which case it will be chosen automatically.

  • process_id (int | None) – The ID number of the current process. The process_id values across the cluster must be a dense range 0, 1, …, num_processes - 1. May be None only on supported environments, in which case it will be chosen automatically.

  • local_device_ids (int | Sequence[int] | None) – Restricts the visible devices of the current process to local_device_ids. If None, all local devices are visible to the process, except when processes are launched via Slurm or Open MPI on GPUs; in that case, it defaults to a single device per process.

  • cluster_detection_method (str | None) – An optional string naming the method used to autodetect the configuration of the distributed run. Note that the “mpi4py” method requires a working mpi4py installation in your environment, and the application must be launched with an MPI-compatible job launcher such as mpiexec or mpirun. The legacy auto-detection methods (Open MPI, Slurm) remain enabled.

  • initialization_timeout (int) – Time period, in seconds, during which the connection will be retried. If initialization takes longer than this timeout, it fails with an error. Defaults to 300 seconds (5 minutes).

  • coordinator_bind_address (str | None) – The address and port to which the coordinator service on process 0 should bind. If this is not specified, the default is to bind to all available addresses on the same port as coordinator_address. On systems that have multiple network interfaces per node, having the coordinator service listen on only one address/interface may be insufficient.
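Several of the parameter constraints above can be checked before calling initialize(), which may produce a clearer message than a timeout. Below is a minimal sketch; check_initialize_args and is_dense_range are hypothetical helpers, not part of JAX, and the loopback test only recognizes localhost and literal loopback IPs. The multi_host flag reflects the note under coordinator_address: a loopback address is presumably only usable when every process runs on the same machine.

```python
import ipaddress
from typing import Iterable, List


def _is_loopback(host: str) -> bool:
    """True for localhost or a loopback IP such as 127.0.0.1 or [::1]."""
    host = host.strip("[]")  # accept bracketed IPv6 like "[::1]"
    if host == "localhost":
        return True
    try:
        return ipaddress.ip_address(host).is_loopback
    except ValueError:
        return False  # some other hostname; cannot tell without DNS


def check_initialize_args(
    coordinator_address: str,
    num_processes: int,
    process_id: int,
    initialization_timeout: int = 300,
    multi_host: bool = True,
) -> List[str]:
    """Hypothetical pre-flight check of the documented constraints.

    Returns a list of human-readable problems; an empty list means none found.
    """
    problems: List[str] = []
    host, sep, port = coordinator_address.rpartition(":")
    if not sep or not host or not port.isdigit():
        problems.append("coordinator_address should look like 'host:port'")
    elif not 0 < int(port) < 65536:
        problems.append("coordinator port must be between 1 and 65535")
    elif multi_host and _is_loopback(host):
        problems.append("loopback address is not reachable from other hosts")
    if num_processes < 1:
        problems.append("num_processes must be at least 1")
    if not 0 <= process_id < num_processes:
        problems.append("process_id must satisfy 0 <= process_id < num_processes")
    if initialization_timeout <= 0:
        problems.append("initialization_timeout must be positive")
    return problems


def is_dense_range(process_ids: Iterable[int], num_processes: int) -> bool:
    """True if the IDs are exactly 0, 1, ..., num_processes - 1, in any order."""
    return sorted(process_ids) == list(range(num_processes))
```

A single process can only check its own process_id; a launcher that knows every process's ID could additionally call is_dense_range on the full list.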

Raises:

RuntimeError – If initialize() is called more than once or if called after the backend is already initialized.
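Because a second call raises RuntimeError, code that might trigger initialization from more than one place (for example, both a library entry point and a script) can route calls through a guard. The sketch below is hypothetical: make_initialize_once is not a JAX API, it only tracks calls made through the wrapper, and it cannot prevent the other failure mode, where JAX computations have already initialized the backend.

```python
import threading
from typing import Any, Callable


def make_initialize_once(init_fn: Callable[..., Any]) -> Callable[..., bool]:
    """Wrap init_fn (e.g. jax.distributed.initialize) so repeat calls are no-ops.

    The returned function yields True when it actually called init_fn and
    False when initialization had already succeeded through this wrapper.
    """
    lock = threading.Lock()
    done = False

    def initialize_once(**kwargs: Any) -> bool:
        nonlocal done
        with lock:  # serialize concurrent callers within this process
            if done:
                return False
            init_fn(**kwargs)
            done = True  # only marked done if init_fn returned without raising
            return True

    return initialize_once
```

Typical use would be initialize_once = make_initialize_once(jax.distributed.initialize), followed by initialize_once(...) with the usual keyword arguments wherever initialization might be needed.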

Examples:

Suppose there are two GPU processes, and process 0 is the designated coordinator with address 10.0.0.1:1234. To initialize the GPU cluster, run the following commands before anything else.

On process 0:

>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)  

On process 1:

>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)
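The two-process example generalizes to one process per GPU across several nodes. The sketch below computes each process's arguments from a node index and a local GPU index, assuming processes are numbered node by node (node 0 holds process IDs 0 to G-1, and so on), so coordinator_address must point at node 0, which hosts process 0. The helper per_gpu_initialize_kwargs and its inputs are hypothetical; in practice the indices would come from your launcher. Setting local_device_ids to a single index restricts each process to one GPU, similar to the Slurm/Open MPI GPU default described under local_device_ids.

```python
from typing import Any, Dict


def per_gpu_initialize_kwargs(
    coordinator_address: str,
    num_nodes: int,
    gpus_per_node: int,
    node_index: int,
    local_gpu: int,
) -> Dict[str, Any]:
    """Hypothetical helper: arguments for one process that owns a single GPU."""
    if not 0 <= node_index < num_nodes:
        raise ValueError("node_index out of range")
    if not 0 <= local_gpu < gpus_per_node:
        raise ValueError("local_gpu out of range")
    return {
        "coordinator_address": coordinator_address,
        "num_processes": num_nodes * gpus_per_node,
        # Node-major numbering keeps process IDs a dense range 0..N-1.
        "process_id": node_index * gpus_per_node + local_gpu,
        # Each process sees only its own GPU.
        "local_device_ids": [local_gpu],
    }
```

Each process would call jax.distributed.initialize(**per_gpu_initialize_kwargs('10.0.0.1:1234', 2, 4, node_index, local_gpu)) with its own indices. With num_nodes=2 and gpus_per_node=1, it yields the same coordinator_address, num_processes, and process_id values as the example above.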