jax.distributed.initialize

jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None)

Initializes the JAX distributed system.

Calling initialize() prepares JAX for execution on multi-host GPU and Cloud TPU systems. It must be called before performing any JAX computations.

The JAX distributed system serves a number of roles:

  • it allows JAX processes to discover each other and share topology information,

  • it performs health checking, ensuring that all processes shut down if any process dies, and

  • it is used for distributed checkpointing.
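After initialize() returns, each process can inspect the topology information that the distributed system shares. The helper below is a small illustrative sketch: describe_topology is a hypothetical name, not a JAX API, and it relies only on the public queries jax.process_index(), jax.process_count(), jax.local_device_count(), and jax.device_count(). Run without initialize() in a single process, it simply reports a one-process cluster.

```python
import jax


def describe_topology() -> dict:
    """Summarize what this process can see of the cluster (hypothetical helper).

    In a multi-process job, call this only after jax.distributed.initialize().
    """
    return {
        "process_index": jax.process_index(),        # ID of this process
        "process_count": jax.process_count(),        # total number of processes
        "local_devices": jax.local_device_count(),   # devices attached to this process
        "global_devices": jax.device_count(),        # devices across all processes
    }
```

In a two-process GPU job, for example, both processes would report process_count 2 and the same global_devices value, while process_index differs.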

If you are using TPU or Slurm, all arguments are optional: if omitted, they will be chosen automatically.

Otherwise, you must provide the coordinator_address, num_processes, and process_id arguments to initialize().
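One way for a launcher to follow this rule is to pass explicit arguments when they are available and fall back to automatic detection otherwise. The sketch below assumes the launcher exports three hypothetical environment variables, COORDINATOR_ADDRESS, NUM_PROCESSES, and PROCESS_ID; JAX itself does not read these names, and initialize_kwargs is an illustrative helper, not part of JAX.

```python
import os
from collections.abc import Mapping

# Hypothetical variable names chosen for this sketch; JAX does not read them.
_VARS = ("COORDINATOR_ADDRESS", "NUM_PROCESSES", "PROCESS_ID")


def initialize_kwargs(env: Mapping[str, str] = os.environ) -> dict:
    """Build keyword arguments for jax.distributed.initialize().

    An empty dict means "let JAX detect everything", which only works in
    supported environments such as Cloud TPU or Slurm.
    """
    present = [name for name in _VARS if name in env]
    if not present:
        return {}  # rely on automatic detection
    if len(present) != len(_VARS):
        missing = sorted(set(_VARS) - set(present))
        raise ValueError(f"missing environment variables: {missing}")
    num_processes = int(env["NUM_PROCESSES"])
    process_id = int(env["PROCESS_ID"])
    # process_id values must form the dense range 0 .. num_processes - 1.
    if not 0 <= process_id < num_processes:
        raise ValueError(
            f"PROCESS_ID must be in 0..{num_processes - 1}, got {process_id}"
        )
    return {
        "coordinator_address": env["COORDINATOR_ADDRESS"],  # "address:port" of process 0
        "num_processes": num_processes,
        "process_id": process_id,
    }
```

Each process would then call jax.distributed.initialize(**initialize_kwargs()) before any other JAX code.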

Parameters
  • coordinator_address (Optional[str]) – The IP address of process 0 and a port on which that process should launch a coordinator service, given as 'address:port' (for example '10.0.0.1:1234'). The choice of port does not matter, so long as the port is available on the coordinator and all processes agree on it. May be None only in supported environments (such as Cloud TPU or Slurm), in which case it is chosen automatically.

  • num_processes (Optional[int]) – Total number of processes in the cluster. May be None only in supported environments, in which case it is chosen automatically.

  • process_id (Optional[int]) – The ID of the current process. The process_id values across the cluster must form the dense range 0, 1, …, num_processes - 1. May be None only in supported environments, in which case it is chosen automatically.

  • local_device_ids (Union[int, Sequence[int], None]) – Restricts the devices visible to the current process to local_device_ids. If None, all local devices are visible to the process, except when processes are launched via Slurm on GPUs, in which case each process defaults to a single device.
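A common GPU layout runs one process per GPU, so that each process sees only its own device. The hypothetical helper below assumes processes are numbered host by host, with the same number of GPUs on every host, so that process_id % gpus_per_host is the process's local GPU index. Both the helper and that numbering scheme are assumptions of this sketch, not behavior of initialize().

```python
def single_gpu_device_ids(process_id: int, gpus_per_host: int) -> list[int]:
    """Pick the one local GPU this process should see (hypothetical helper).

    Assumes gpus_per_host consecutive process IDs share each host, so the
    remainder is the local rank of the process on its host.
    """
    if gpus_per_host <= 0:
        raise ValueError("gpus_per_host must be positive")
    return [process_id % gpus_per_host]
```

For example, process 5 in an 8-process job on hosts with 4 GPUs each would pass local_device_ids=single_gpu_device_ids(5, 4), which is [1], alongside its coordinator_address, num_processes, and process_id.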

Raises

RuntimeError – If initialize() is called more than once.
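Because a second call raises RuntimeError, code that can be reached from several entry points may want to route initialization through a guard. The sketch below is a hypothetical helper, not part of JAX. It only knows about calls made through itself, and it marks initialization as done only after the wrapped call returns without raising.

```python
from typing import Any, Callable

_initialized = False  # tracks calls made through initialize_once only


def initialize_once(init: Callable[..., Any], **kwargs: Any) -> bool:
    """Call init(**kwargs) the first time; return True if it ran, False if skipped."""
    global _initialized
    if _initialized:
        return False
    init(**kwargs)
    _initialized = True  # set only after init returned successfully
    return True
```

Usage would look like initialize_once(jax.distributed.initialize, coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0); a repeated call through the helper is skipped instead of raising.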

Example:

Suppose there are two GPU processes, and process 0 is the designated coordinator with address 10.0.0.1:1234. To initialize the GPU cluster, run the following commands in each process before any other JAX code.

On process 0:

>>> import jax
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)

On process 1:

>>> import jax
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)
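Since only process_id differs between the two calls, both processes can run the same script and receive their ID on the command line. The sketch below uses argparse; the flag names, their defaults, and the function names parse_args and initialize_from_args are illustrative assumptions, not a JAX launcher interface.

```python
import argparse
from typing import Optional, Sequence

import jax


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Read the cluster settings; only --process-id differs between processes."""
    parser = argparse.ArgumentParser(description="Two-process JAX launcher (sketch)")
    parser.add_argument("--coordinator-address", default="10.0.0.1:1234")
    parser.add_argument("--num-processes", type=int, default=2)
    parser.add_argument("--process-id", type=int, required=True)
    return parser.parse_args(argv)


def initialize_from_args(argv: Optional[Sequence[str]] = None) -> None:
    """Initialize the distributed system; call before any JAX computation."""
    args = parse_args(argv)
    jax.distributed.initialize(
        coordinator_address=args.coordinator_address,
        num_processes=args.num_processes,
        process_id=args.process_id,
    )
```

The script would call initialize_from_args() first thing, and be launched with --process-id 0 on process 0 and --process-id 1 on process 1.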