GPU performance tips#

This document focuses on performance tips for neural network workloads on GPUs.

Matmul precision#

On recent GPU generations, such as the Nvidia A100 generation or later, it can be a good idea to perform most computations in bfloat16 precision. For example, if using Flax, instantiate Dense layers using flax.linen.Dense(..., dtype=jax.numpy.bfloat16).
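
As a minimal sketch (assuming Flax is installed; the layer sizes, input shapes, and the MLP module name below are illustrative, not part of the original text), a small model computing in bfloat16 might look like this:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # dtype=jnp.bfloat16 makes the layer compute in bfloat16;
        # parameters are still stored in float32 by default.
        x = nn.Dense(features=1024, dtype=jnp.bfloat16)(x)
        x = nn.relu(x)
        return nn.Dense(features=1024, dtype=jnp.bfloat16)(x)

model = MLP()
x = jnp.ones((8, 512), dtype=jnp.bfloat16)
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)  # y has dtype bfloat16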

XLA performance flags#

The existence and exact behavior of XLA flags may be jaxlib-version dependent.

As of jaxlib==0.4.18 (released Oct 6 2023), setting these XLA flags can improve performance. Some are related to communication between GPUs, and so are only relevant when running computations on multiple devices, while others are related to code generation on each device.

Some of these may be set by default in future releases.

These flags can be set via the XLA_FLAGS shell environment variable. For example, we can add this to the top of a Python file:

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)

For more examples, see also XLA Flags recommended for Pax training on Nvidia GPUs.

Code generation flags#

  • --xla_gpu_enable_triton_softmax_fusion This flag enables an automatic softmax fusion, based on pattern-matching backed by Triton code generation. The default value is False.

  • --xla_gpu_triton_gemm_any Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False. (See the sketch after this list for one way to enable these two flags.)
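
As a hedged sketch (the function name, shapes, and dtypes below are illustrative), these two code-generation flags can be set before JAX initializes its GPU backend, and a jitted function containing the matmul-plus-softmax pattern they target might look like this:

import os
# Must be set before JAX initializes the GPU backend.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true "
    "--xla_gpu_triton_gemm_any=True"
)

import jax
import jax.numpy as jnp

@jax.jit
def attention_scores(q, k):
    # A GEMM followed by a softmax: the patterns targeted by the two flags above.
    logits = q @ k.T
    return jax.nn.softmax(logits, axis=-1)

q = jnp.ones((128, 64), dtype=jnp.bfloat16)
k = jnp.ones((128, 64), dtype=jnp.bfloat16)
print(attention_scores(q, k).shape)  # (128, 128)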

Communication flags#

  • --xla_gpu_enable_async_collectives This flag enables collective ops such as AllReduce, AllGather, ReduceScatter and CollectivePermute to be asynchronous. Asynchronous communication can overlap cross-core communication with computation. The default value is False.

  • --xla_gpu_enable_latency_hiding_scheduler This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False.

  • --xla_gpu_enable_pipelined_collectives When using pipeline parallelism, this flag enables overlapping the (i+1)-th layer weight AllGather with the i-th layer computation. It also enables overlapping the (i+1)-th layer weight Reduce/ReduceScatter with the i-th layer’s computation. The default value is False. Note that there are some known bugs when this flag is enabled.

  • --xla_gpu_collective_permute_decomposer_threshold This flag is useful when performing GSPMD pipelining. Setting a nonzero threshold decomposes CollectivePermutes into CollectivePermuteReceiveDone and CollectivePermuteSendDone pairs, so that computation can be performed between each corresponding ReceiveDone/SendDone pair, achieving more overlap. By default the threshold is 0 and no decomposition is performed. Setting it to a threshold > 0, for example --xla_gpu_collective_permute_decomposer_threshold=1024, enables this feature.

  • --xla_gpu_all_gather_combine_threshold_bytes --xla_gpu_reduce_scatter_combine_threshold_bytes --xla_gpu_all_reduce_combine_threshold_bytes These flags tune when to combine multiple small AllGather/ReduceScatter/AllReduce operations into one big AllGather/ReduceScatter/AllReduce to reduce time spent on cross-device communication. For example, for the AllGather/ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough to combine at least a Transformer layer’s weight AllGather/ReduceScatter. By default, each combine_threshold_bytes is set to 256. (A sketch of setting some of these communication flags follows this list.)
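
As a hedged sketch (the threshold values and the allreduce helper below are illustrative, not tuned recommendations), the communication flags can be combined in XLA_FLAGS before importing JAX; the pmap/psum snippet then exercises a simple multi-device AllReduce:

import os
# Must be set before JAX initializes the GPU backend; values are illustrative.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_async_collectives=true "
    "--xla_gpu_enable_latency_hiding_scheduler=true "
    "--xla_gpu_all_gather_combine_threshold_bytes=8388608 "
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=8388608"
)

import jax
import jax.numpy as jnp

def allreduce(x):
    # Sums x across all devices participating in the mapped axis "i".
    return jax.lax.psum(x, axis_name="i")

n = jax.local_device_count()  # needs more than one device to be interesting
xs = jnp.arange(n, dtype=jnp.float32)
print(jax.pmap(allreduce, axis_name="i")(xs))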

NCCL flags#

These NCCL flag values may be useful for single-host multi-device computations on Nvidia GPUs:

os.environ.update({
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
})

These NCCL flags could improve single-host communication speed; they don’t yet appear to help multi-host communication.