JAX Glossary of Terms

CPU

Short for Central Processing Unit, CPUs are the standard computational architecture available in most computers. JAX can run computations on CPUs, but can often achieve much better performance on GPUs and TPUs.

Device

The generic name used to refer to the CPU, GPU, or TPU used by JAX to perform computations.
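
As a small illustration, the devices visible to JAX can be listed with jax.devices(). This is a minimal sketch; the output depends entirely on the machine it runs on.

```python
import jax

# List every device JAX can see. On a machine without accelerators
# this is typically a single CPU device.
devices = jax.devices()
for d in devices:
    print(d.platform, d.id)

# Name of the backend JAX uses by default, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print(backend)
```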

DeviceArray

JAX’s analog of the numpy.ndarray. See jax.interpreters.xla.DeviceArray.

forward-mode autodiff

See JVP.

functional programming

A programming paradigm in which programs are defined by applying and composing pure functions. JAX is designed for use with functional programs.

GPU

Short for Graphics Processing Unit, GPUs were originally specialized for operations related to rendering images on screen, but are now much more general-purpose. JAX is able to target GPUs for fast operations on arrays (see also CPU and TPU).

jaxpr

Short for JAX Expression, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to XLA for compilation and execution. See Understanding Jaxprs for more information.
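
A jaxpr can be inspected with jax.make_jaxpr(). The function f below is a hypothetical example chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative function, not taken from the JAX documentation.
    return jnp.sin(x) * 2.0

# make_jaxpr traces f with an example input and returns the
# intermediate representation of the traced computation.
closed_jaxpr = jax.make_jaxpr(f)(1.0)
print(closed_jaxpr)
```

The printed jaxpr lists the primitive operations (here a sine and a multiplication) that JAX recorded while tracing f.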

JIT

Short for Just In Time compilation, JIT in JAX generally refers to the compilation of array operations to XLA, most often accomplished using jax.jit().
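
A minimal sketch of JIT compilation with jax.jit(); the function normalize is a hypothetical example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # Shift to zero mean and scale to unit standard deviation.
    return (x - x.mean()) / x.std()

x = jnp.arange(5.0)
# The first call traces and compiles the function; later calls with
# inputs of the same shape and dtype reuse the compiled version.
y = normalize(x)
```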

JVP

Short for Jacobian Vector Product, also sometimes known as forward-mode automatic differentiation. For more details, see Jacobian-Vector products (JVPs, aka forward-mode autodiff). In JAX, JVP is a transformation that is implemented via jax.jvp(). See also VJP.
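
As a sketch, jax.jvp() takes a function, a tuple of primal inputs, and a tuple of tangent vectors, and returns the function output together with the Jacobian applied to the tangent. The function f and the inputs are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Elementwise square; its Jacobian is diag(2 * x).
    return x ** 2

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])  # tangent vector

# primal_out is f(x); tangent_out is the Jacobian of f at x times v.
primal_out, tangent_out = jax.jvp(f, (x,), (v,))
```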

pure function

A pure function is a function whose outputs are based only on its inputs, and which has no side-effects. JAX’s transformation model is designed to work with pure functions. See also functional programming.
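
The sketch below contrasts a pure function with an impure one, and shows one reason purity matters under jax.jit(): a Python side effect runs while the function is traced, not on every call. The names counter, impure_add_one, and pure_add_one are hypothetical.

```python
import jax
import jax.numpy as jnp

counter = 0  # global state, used only to demonstrate a side effect

def impure_add_one(x):
    global counter
    counter += 1  # side effect: executed only while JAX traces the function
    return x + 1

def pure_add_one(x):
    # The output depends only on the input, and nothing else is modified.
    return x + 1

jitted = jax.jit(impure_add_one)
jitted(jnp.float32(1.0))
jitted(jnp.float32(2.0))  # same shape and dtype: the cached trace is reused

# The function was called twice but traced once, so counter is 1.
print(counter)
```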

reverse-mode autodiff

See VJP.

SPMD

Short for Single Program, Multiple Data, it refers to a parallel computation technique in which the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs). jax.pmap() is a JAX transformation that implements SPMD parallelism.
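
A minimal sketch of SPMD with jax.pmap(), assuming the leading axis of the input has one entry per local device. The function step is hypothetical; on a machine with a single device this still runs, just without any real parallelism.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def step(x):
    # The same program runs on every device, each with its own slice of data.
    return jnp.sum(x ** 2)

# One row of data per device: shape (n, 4).
batch = jnp.arange(n * 4.0).reshape(n, 4)
per_device = jax.pmap(step)(batch)  # shape (n,)
```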

static

In a JIT compilation, a value that is not traced (see Tracer). Also sometimes refers to compile-time computations on static values.
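
As an illustration, an argument can be marked static with the static_argnums parameter of jax.jit(). The function power is a hypothetical example; because n is static, it is an ordinary Python int during tracing and can drive a Python loop.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def power(x, n):
    # n is static: a concrete Python int at trace time, not a tracer.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

y = power(jnp.array([1.0, 2.0, 3.0]), 3)
```

Calling power with a different value of n triggers a new compilation, since static values are part of what the compiled code is specialized on.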

TPU

Short for Tensor Processing Unit, TPUs are chips specifically engineered for fast operations on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for fast operations on arrays (see also CPU and GPU).

Tracer

An object used as a stand-in for a JAX DeviceArray in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the jax.core.Tracer class.
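
The sketch below records, during tracing, whether the argument seen inside a jitted function is a tracer. The names seen and f are hypothetical.

```python
import jax
import jax.numpy as jnp

seen = []

@jax.jit
def f(x):
    # While JAX traces f, x is a Tracer standing in for the real array.
    seen.append(isinstance(x, jax.core.Tracer))
    return x * 2

out = f(jnp.arange(3))  # the returned value is a concrete array, not a tracer
```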

transformation

A higher-order function: that is, a function that takes a function as input and outputs a transformed function. Examples in JAX include jax.jit(), jax.vmap(), and jax.grad().
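
Because each transformation returns a function, transformations compose. A minimal sketch, using a hypothetical loss function:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Illustrative scalar-valued function of weights w and one example x.
    return jnp.sum((w * x) ** 2)

grad_loss = jax.grad(loss)                        # gradient with respect to w
batched = jax.vmap(grad_loss, in_axes=(None, 0))  # map over a batch of x
fast = jax.jit(batched)                           # compile the result

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0], [2.0, 0.5]])
grads = fast(w, xs)  # one gradient per row of xs, shape (2, 2)
```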

VJP

Short for Vector Jacobian Product, also sometimes known as reverse-mode automatic differentiation. For more details, see Vector-Jacobian products (VJPs, aka reverse-mode autodiff). In JAX, VJP is a transformation that is implemented via jax.vjp(). See also JVP.
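
As a sketch, jax.vjp() returns the function output together with a function that maps a cotangent vector to its product with the Jacobian. The function f and the inputs are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Elementwise square; its Jacobian is diag(2 * x).
    return x ** 2

x = jnp.array([1.0, 2.0, 3.0])
primal_out, vjp_fn = jax.vjp(f, x)

u = jnp.array([1.0, 0.0, 0.0])  # cotangent vector
# vjp_fn returns a tuple with one entry per input of f.
(cotangent,) = vjp_fn(u)
```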

XLA

Short for Accelerated Linear Algebra, XLA is a domain-specific compiler for linear algebra operations that is the primary backend for JIT-compiled JAX code. See https://www.tensorflow.org/xla/.

weak type

A JAX data type that has the same type promotion semantics as Python scalars; see Weakly-typed values in JAX.
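
A minimal sketch of the effect on type promotion: a Python scalar is weakly typed and does not widen the array dtype, whereas a scalar with an explicit dtype does.

```python
import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.int8)

a = x + 1             # Python int is weakly typed: result stays int8
b = x + jnp.int32(1)  # explicit int32 is strongly typed: result is int32

print(a.dtype, b.dtype)
```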