# List of XLA compiler flags
## Introduction
This guide gives a brief overview of XLA and how XLA relates to JAX; for in-depth details, please refer to the official XLA documentation. It then lists commonly used XLA compiler flags designed to optimize the performance of JAX programs.
## XLA: The Powerhouse Behind JAX
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in JAX's performance and flexibility. It enables JAX to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by transforming and compiling your Python/NumPy-like code into efficient machine instructions.
JAX uses XLA's JIT compilation capabilities to transform your Python functions into optimized XLA computations at runtime.
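As a quick illustration, here is a minimal sketch of this JIT workflow (the function and array shapes are arbitrary examples, not taken from this page):

```python
import jax
import jax.numpy as jnp

@jax.jit  # XLA compiles this function the first time it is called with these shapes/dtypes
def predict(w, x):
    # A small linear layer followed by a nonlinearity; XLA fuses these ops.
    return jnp.tanh(x @ w)

w = jnp.ones((128, 64))
x = jnp.ones((32, 128))
y = predict(w, x)  # first call traces and compiles; later calls reuse the compiled executable
```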
## Configuring XLA in JAX
You can influence XLA's behavior in JAX by setting the `XLA_FLAGS` environment variable before running your Python script or Colab notebook.
For Colab notebooks, provide flags using `os.environ['XLA_FLAGS']`:

```python
import os
# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
```
For Python scripts, specify `XLA_FLAGS` as part of the CLI command:

```bash
XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
```
Important notes:

- Set `XLA_FLAGS` before importing JAX or other relevant libraries. Changing `XLA_FLAGS` after the backend has been initialized has no effect, and since the exact time of backend initialization is not clearly defined, it is usually safer to set `XLA_FLAGS` before executing any JAX code.
- Experiment with different flags to optimize performance for your specific use case.
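As a sketch of this ordering (the specific flags below are illustrative picks from the GPU table later on this page, not tuning advice):

```python
import os

# XLA_FLAGS must be set before JAX initializes its backend,
# so set it before `import jax`.
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_triton_gemm=false'
)

import jax  # the flags above are picked up when the backend is initialized
print(jax.devices())
```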
For further information:

- Complete and up-to-date documentation about XLA can be found in the official XLA documentation.
- For backends supported by the open-source version of XLA (CPU, GPU), XLA flags are defined with their default values in xla/debug_options_flags.cc, and a complete list of flags can be found here.
- TPU compiler flags are not part of OpenXLA, but commonly used options are listed below.

Please note that this list of flags is not exhaustive and is subject to change. These flags are implementation details, and there is no guarantee that they will remain available or keep their current behavior.
## Common XLA flags
| Flag | Type | Notes |
|---|---|---|
| `xla_dump_to` | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see XLA Tools). |
| `xla_enable_async_collective_permute` | TristateFlag (true/false/auto) | Rewrites all collective-permute operations to their asynchronous variants. When set to `auto`, XLA can turn on the async variant automatically based on other configurations or conditions. |
| `xla_enable_async_all_gather` | TristateFlag (true/false/auto) | If set to `true`, enables async all-gather. If `auto`, enables it only for platforms that implement async all-gather. |
| `xla_disable_hlo_passes` | String (comma-separated list of pass names) | Comma-separated list of HLO passes to be disabled. These names must exactly match the pass names (no whitespace around commas). |
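For example, a minimal sketch (the dump directory and the toy function are arbitrary choices) that uses `--xla_dump_to` to write pre-optimization HLO for a small jitted function:

```python
import os

# Dump pre-optimization HLO and other artifacts to /tmp/xla_dump
# (an arbitrary example directory).
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump'

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

f(jnp.arange(4.0))  # compiling this writes HLO files into /tmp/xla_dump
```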
## TPU XLA flags
| Flag | Type | Notes |
|---|---|---|
| `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data-parallel sharding. |
| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data-parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. |
| `xla_tpu_enable_async_collective_fusion` | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop fusion or convolution) that are scheduled between their -start and -done instructions. |
| `xla_tpu_enable_async_collective_fusion_fuse_all_gather` | TristateFlag (true/false/auto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass. |
| `xla_tpu_enable_async_collective_fusion_multiple_steps` | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass. |
| `xla_tpu_overlap_compute_collective_tc` | Boolean (true/false) | Enables the overlap of compute and communication on a single TensorCore, i.e., the one-core equivalent of MegaCore fusion. |
| `xla_tpu_spmd_rng_bit_generator_unsafe` | Boolean (true/false) | Whether to run the RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation. |
| `xla_tpu_megacore_fusion_allow_ags` | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces. |
| `xla_tpu_enable_ag_backward_pipelining` | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops. |
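For instance, here is a hedged sketch (this particular combination of flags is illustrative only, not tuning advice) of enabling several of the TPU fusion/overlap flags before importing JAX:

```python
import os

# Illustrative combination of TPU flags from the table above;
# whether these help depends entirely on the workload.
os.environ['XLA_FLAGS'] = (
    '--xla_tpu_enable_async_collective_fusion=true '
    '--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true '
    '--xla_tpu_overlap_compute_collective_tc=true'
)

import jax  # the flags take effect when the TPU backend initializes
```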
## GPU XLA flags
| Flag | Type | Notes |
|---|---|---|
| `xla_gpu_enable_latency_hiding_scheduler` | Boolean (true/false) | Enables the latency-hiding scheduler, which overlaps asynchronous communication with computation. The default value is false. |
| `xla_gpu_enable_triton_gemm` | Boolean (true/false) | Use Triton-based matrix multiplication. |
| `xla_gpu_graph_level` | Flag (0-3) | The legacy flag for setting the GPU graph level. Use xla_gpu_enable_command_buffer in new use cases. 0 = off; 1 = capture fusions and memcpys; 2 = capture gemms; 3 = capture convolutions. |
| `xla_gpu_all_reduce_combine_threshold_bytes` | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce ops into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough to combine at least a Transformer layer's weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256. (See the sketch after this table.) |
| `xla_gpu_all_gather_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. |
| `xla_gpu_reduce_scatter_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. |
| `xla_gpu_enable_pipelined_all_gather` | Boolean (true/false) | Enable pipelining of all-gather instructions. |
| `xla_gpu_enable_pipelined_reduce_scatter` | Boolean (true/false) | Enable pipelining of reduce-scatter instructions. |
| `xla_gpu_enable_pipelined_all_reduce` | Boolean (true/false) | Enable pipelining of all-reduce instructions. |
| `xla_gpu_enable_while_loop_double_buffering` | Boolean (true/false) | Enable double-buffering for while loops. |
| `xla_gpu_enable_triton_softmax_fusion` | Boolean (true/false) | Use Triton-based softmax fusion. |
| `xla_gpu_all_gather_combine_by_dim` | Boolean (true/false) | Combine all-gather ops with the same gather dimension, or irrespective of their dimension. |
| `xla_gpu_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension, or irrespective of their dimension. |
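As referenced in the combine-threshold row above, here is a minimal sketch (the threshold values are placeholders, not recommendations) of raising the combine thresholds and enabling the latency-hiding scheduler for a GPU run:

```python
import os

# Placeholder values: in practice, size the thresholds so that at least one
# Transformer layer's weights are combined per AllGather / ReduceScatter.
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_all_gather_combine_threshold_bytes=8589934592 '
    '--xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592'
)

import jax  # the flags are read when the GPU backend initializes
```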
Additional reading: