JAX Glossary of Terms#


Array#

JAX’s analog of numpy.ndarray. See jax.Array.


CPU#

Short for Central Processing Unit, CPUs are the standard computational architecture available in most computers. JAX can run computations on CPUs, but can often achieve much better performance on GPU and TPU.


Device#

The generic name used to refer to the CPU, GPU, or TPU used by JAX to perform computations.
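As a short illustrative sketch (assuming a standard JAX installation), the devices JAX can see are listed with jax.devices():

```python
import jax

# List the devices visible to JAX; on a machine without accelerators
# this is typically a single CPU device.
devices = jax.devices()
print(devices[0].platform)  # e.g. "cpu", "gpu", or "tpu"
```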

forward-mode autodiff#

See JVP.

functional programming#

A programming paradigm in which programs are defined by applying and composing pure functions. JAX is designed for use with functional programs.


GPU#

Short for Graphics Processing Unit, GPUs were originally specialized for operations related to rendering of images on screen, but are now much more general-purpose. JAX is able to target GPUs for fast operations on arrays (see also CPU and TPU).


jaxpr#

Short for JAX Expression, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to XLA for compilation and execution. See Understanding Jaxprs for more discussion and examples.
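A minimal sketch of inspecting a jaxpr, using jax.make_jaxpr:

```python
import jax

def f(x):
    return x ** 2 + 1.0

# jax.make_jaxpr traces the function on example inputs and returns
# its jaxpr, the intermediate representation JAX hands to XLA.
jaxpr = jax.make_jaxpr(f)(2.0)
print(jaxpr)
```

Each operation appearing in the printed jaxpr is a primitive.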


JIT#

Short for Just In Time compilation, JIT in JAX generally refers to the compilation of array operations to XLA, most often accomplished using jax.jit().
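A small sketch of jax.jit in use; the function is traced and compiled on its first call, and subsequent calls with the same input shapes reuse the compiled code:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x * 2.0 + 1.0

# Compiled via XLA on first call.
print(f(jnp.arange(3.0)))  # [1. 3. 5.]
```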


JVP#

Short for Jacobian Vector Product, also sometimes known as forward-mode automatic differentiation. For more details, see Jacobian-Vector products (JVPs, aka forward-mode autodiff). In JAX, JVP is a transformation that is implemented via jax.jvp(). See also VJP.
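A minimal sketch of jax.jvp on a scalar function:

```python
import jax

def f(x):
    return x ** 2

# Forward-mode: push the tangent v through f at the point x.
# Returns (f(x), J(x) @ v); for this scalar f, J(x) is just 2x.
y, tangent = jax.jvp(f, (3.0,), (1.0,))
print(y, tangent)  # 9.0 6.0
```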


primitive#

A primitive is a fundamental unit of computation used in JAX programs. Most functions in jax.lax represent individual primitives. When representing a computation in a jaxpr, each operation in the jaxpr is a primitive.
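As a small sketch, a primitive can be applied by calling the corresponding jax.lax function directly:

```python
import jax.numpy as jnp
from jax import lax

# jnp.add dispatches to the add primitive in jax.lax; calling the
# lax function applies that primitive directly (note lax functions
# require matching dtypes, unlike jnp's promoting wrappers).
print(lax.add(jnp.float32(1.0), jnp.float32(2.0)))  # 3.0
```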

pure function#

A pure function is a function whose outputs are based only on its inputs, and which has no side effects. JAX’s transformation model is designed to work with pure functions. See also functional programming.
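To illustrate why purity matters, a small sketch of an impure function under jax.jit; the external value is captured once at trace time, so later changes to it are silently ignored:

```python
import jax

scale = 2.0

@jax.jit
def impure(x):
    # Impure: reads state outside its inputs. The value of `scale`
    # is baked in when the function is first traced.
    return x * scale

print(impure(3.0))  # 6.0
scale = 10.0
print(impure(3.0))  # still 6.0: the traced value was captured
```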


pytree#

A pytree is an abstraction that lets JAX handle tuples, lists, dicts, and other more general containers of array values in a uniform way. Refer to Working with pytrees for a more detailed discussion.
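A minimal sketch of uniform pytree handling, using jax.tree_util.tree_map:

```python
import jax

# tree_map applies a function to every leaf of a pytree,
# preserving the container structure.
tree = {"a": 1, "b": [2, 3]}
doubled = jax.tree_util.tree_map(lambda x: x * 2, tree)
print(doubled)  # {'a': 2, 'b': [4, 6]}
```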

reverse-mode autodiff#

See VJP.


SPMD#

Short for Single Program, Multiple Data, SPMD refers to a parallel computation technique in which the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs). jax.pmap() is a JAX transformation that implements SPMD parallelism.
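A hedged sketch of jax.pmap; the leading axis of the input is split across devices, and on a CPU-only machine there is typically a single device, so this degenerates to one device:

```python
import jax
import jax.numpy as jnp

# Run the same computation on each available device, one slice
# of the leading axis per device.
n = jax.device_count()
out = jax.pmap(lambda x: x * 2)(jnp.arange(n))
print(out)
```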


static#

In a JIT compilation, a value that is not traced (see Tracer). Also sometimes refers to compile-time computations on static values.
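A small sketch of a static argument under jax.jit, using static_argnums:

```python
from functools import partial
import jax

# n is marked static: at trace time it is a plain Python int, so
# ordinary Python control flow on it is allowed. Each distinct
# value of n triggers a fresh compilation.
@partial(jax.jit, static_argnums=1)
def power(x, n):
    if n == 0:
        return x * 0.0 + 1.0
    return x ** n

print(power(2.0, 3))  # 8.0
```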


Short for Tensor Processing Unit, TPUs are chips specifically engineered for fast operations on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for fast operations on arrays (see also CPU and GPU).


Tracer#

An object used as a stand-in for a JAX Array in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the jax.core.Tracer class.
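A minimal sketch showing that under jax.jit the function sees a Tracer, not a concrete array:

```python
import jax

trace_types = []

def f(x):
    # Under jax.jit, x is a Tracer recording the operations
    # applied to it rather than holding concrete values.
    trace_types.append(type(x).__name__)
    return x * 2

result = jax.jit(f)(3.0)
print(trace_types)
print(result)  # 6.0
```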


transformation#

A higher-order function: that is, a function that takes a function as input and outputs a transformed function. Examples in JAX include jax.jit(), jax.vmap(), and jax.grad().
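A small sketch of two transformations applied to plain Python functions:

```python
import jax
import jax.numpy as jnp

# jax.grad transforms a function into one computing its gradient.
df = jax.grad(lambda x: x ** 2)
print(df(3.0))  # 6.0

# jax.vmap transforms a scalar function into a vectorized one.
f_batched = jax.vmap(lambda x: x ** 2)
print(f_batched(jnp.arange(3.0)))  # [0. 1. 4.]
```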


VJP#

Short for Vector Jacobian Product, also sometimes known as reverse-mode automatic differentiation. For more details, see Vector-Jacobian products (VJPs, aka reverse-mode autodiff). In JAX, VJP is a transformation that is implemented via jax.vjp(). See also JVP.
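A minimal sketch of jax.vjp on a scalar function:

```python
import jax

def f(x):
    return x ** 2

# Reverse-mode: jax.vjp returns f(x) and a function that pulls a
# cotangent back through f; for scalar f this yields the gradient.
y, vjp_fn = jax.vjp(f, 3.0)
(grad_x,) = vjp_fn(1.0)
print(y, grad_x)  # 9.0 6.0
```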


XLA#

Short for Accelerated Linear Algebra, XLA is a domain-specific compiler for linear algebra operations that is the primary backend for JIT-compiled JAX code. See https://www.tensorflow.org/xla/.

weak type#

A JAX data type that has the same type promotion semantics as Python scalars; see Weakly-typed values in JAX.
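A small sketch of weak-type promotion, assuming JAX's standard promotion semantics: a Python scalar does not widen the result dtype, while a strongly-typed array does:

```python
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float16)

# A Python scalar is weakly typed: the result keeps x's dtype.
print((x + 1.0).dtype)  # float16

# An explicitly-typed float32 array is strongly typed and wins promotion.
print((x + jnp.array(1.0, dtype=jnp.float32)).dtype)  # float32
```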