JAX Glossary of Terms
- CPU
Short for Central Processing Unit, CPUs are the standard computational architecture available in most computers. JAX can run computations on CPUs, but can often achieve much better performance on GPUs and TPUs.
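As a rough sketch, assuming a standard jaxlib install (where the CPU backend is always present), the following reports JAX's default platform and explicitly places a small computation on a CPU device. On a machine with a supported accelerator, `jax.default_backend()` would typically report that accelerator instead of `"cpu"`.

```python
import jax
import jax.numpy as jnp

# The platform JAX selected by default, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()

# The CPU backend is available even when an accelerator is the default.
cpu = jax.devices("cpu")[0]

# Commit an array to the CPU device; operations on it run there too.
x = jax.device_put(jnp.arange(4.0), cpu)
y = jnp.dot(x, x)  # 0*0 + 1*1 + 2*2 + 3*3

print(backend, cpu.platform, float(y))
```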
JAX’s analog of the
- forward-mode autodiff
See JVP.
- functional programming
A programming paradigm in which programs are defined by applying and composing pure functions. JAX is designed for use with functional programs.
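A small illustration of composition in this style: JAX transformations such as `jax.grad` and `jax.jit` take a function and return a new function, so they can be stacked. The function `f` below is an arbitrary example chosen for this sketch, not something defined by JAX.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A pure function: the output depends only on x.
    return x * jnp.sin(x)

# Each transformation returns a new function, so they compose freely.
df = jax.grad(f)          # f'(x)  = sin(x) + x*cos(x)
d2f = jax.grad(df)        # f''(x) = 2*cos(x) - x*sin(x)
fast_d2f = jax.jit(d2f)   # compiled version of the second derivative

print(df(0.0), fast_d2f(0.0))
```

Because `f` has no side effects, differentiating it twice and then compiling the result is safe regardless of the order in which the transformations are applied.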
- GPU
Short for Graphics Processing Unit, GPUs were originally specialized for operations related to rendering images on screen, but are now much more general-purpose. JAX is able to target GPUs for fast operations on arrays (see also CPU and TPU).
- jaxpr
Short for JAX Expression, a jaxpr is an intermediate representation of a computation that is generated by JAX and forwarded to XLA for compilation and execution. See Understanding Jaxprs for more information.
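One way to look at the jaxpr JAX builds for a function is `jax.make_jaxpr`. The function `f` here is a hypothetical example, and the exact printed form of the jaxpr varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y + 1.0

# make_jaxpr traces f with example arguments and returns the jaxpr
# (the intermediate representation) instead of computing a value.
closed_jaxpr = jax.make_jaxpr(f)(1.0, 2.0)
print(closed_jaxpr)
```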
- JVP
Short for Jacobian-Vector Product, also sometimes known as forward-mode automatic differentiation. For more details, see Jacobian-Vector products (JVPs, aka forward-mode autodiff). In JAX, JVP is a transformation that is implemented via jax.jvp(). See also VJP.
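A minimal sketch of `jax.jvp`, assuming a hypothetical function `f` from R^3 to R^2 and an arbitrary tangent vector `v`. The result is cross-checked against the full Jacobian from `jax.jacfwd`, which `jax.jvp` itself never builds.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Example map from R^3 to R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.0])
v = jnp.array([1.0, 0.0, 0.0])  # tangent (direction) vector

# jax.jvp evaluates f(x) and J(x) @ v in a single forward pass,
# without materializing the Jacobian J.
primal_out, tangent_out = jax.jvp(f, (x,), (v,))

# Cross-check against the explicitly built Jacobian.
expected = jax.jacfwd(f)(x) @ v
print(primal_out, tangent_out, expected)
```

Here J(x) = [[2, 1, 0], [0, 0, 1]], so J(x) @ v picks out its first column, [2, 0].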
- pure function
A pure function is a function whose outputs depend only on its inputs and which has no side effects. JAX’s transformation model is designed to work with pure functions. See also functional programming.
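The following sketch shows why side effects interact badly with transformations. The names `impure_double`, `pure_double`, and `trace_count` are hypothetical. Under `jax.jit`, Python-level side effects typically run only while JAX traces the function, not on every call.

```python
import jax

trace_count = 0

def impure_double(x):
    global trace_count
    trace_count += 1      # side effect: mutates global Python state
    return x * 2

jitted = jax.jit(impure_double)
a = jitted(1.0)
b = jitted(2.0)

# The counter reads 1, not 2: the side effect ran only while JAX traced
# the function, and the second call reused the cached compiled version.
print(a, b, trace_count)

def pure_double(x):
    return x * 2   # no side effects; behaves the same under any transformation
```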
- reverse-mode autodiff
See VJP.
- SPMD
Short for Single Program, Multiple Data, SPMD refers to a parallel computation technique in which the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs). jax.pmap() is a JAX transformation that implements SPMD parallelism.
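A minimal `jax.pmap` sketch, assuming only whatever local devices are available (on a plain CPU host this is usually a single device, so the "parallel" run has one participant). The axis name `"devices"` is an arbitrary label chosen for this example.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()   # 1 on a plain CPU host; more with multiple accelerators

# One row of input data per device.
xs = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# The same program runs on every device, each on its own row.
ys = jax.pmap(lambda x: 2.0 * x + 1.0)(xs)

# Collectives let the per-device programs communicate: psum adds
# the per-device partial sums across the named mapped axis "devices".
totals = jax.pmap(lambda x: jax.lax.psum(x.sum(), axis_name="devices"),
                  axis_name="devices")(xs)
print(ys.shape, totals)
```

The leading axis of the input must equal the number of devices being mapped over. Every device ends up holding the same global total after `psum`.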
- TPU
Short for Tensor Processing Unit, TPUs are chips specifically engineered for fast operations on N-dimensional tensors used in deep learning applications. JAX is able to target TPUs for fast operations on arrays (see also CPU and GPU).
An object used as a stand-in for a JAX DeviceArray in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the
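The stand-in objects described above can be observed by recording what a function receives while JAX traces it. This is a sketch only: the function `f` and the list `observed` are hypothetical, and the exact tracer class names differ between transformations and JAX versions.

```python
import jax
import jax.numpy as jnp

observed = []

def f(x):
    # Record what JAX passes in while tracing: a stand-in object that
    # carries shape and dtype but no concrete values.
    observed.append((type(x).__name__, x.shape, x.dtype))
    return jnp.sum(x ** 2)

jax.jit(f)(jnp.arange(3.0))
jax.grad(f)(jnp.arange(3.0))
print(observed)
```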
- VJP
Short for Vector-Jacobian Product, also sometimes known as reverse-mode automatic differentiation. For more details, see Vector-Jacobian products (VJPs, aka reverse-mode autodiff). In JAX, VJP is a transformation that is implemented via jax.vjp(). See also JVP.
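A minimal sketch of `jax.vjp`, assuming a hypothetical function `f` from R^3 to R^2 and an arbitrary cotangent vector `u` shaped like the output. The result is checked against the full Jacobian from `jax.jacrev`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Example map from R^3 to R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.0])
u = jnp.array([1.0, 1.0])   # cotangent vector, shaped like f's output

# jax.vjp runs f forward and returns a function that maps a cotangent u
# to u^T J(x), computed with a single backward pass.
primal_out, vjp_fn = jax.vjp(f, x)
(cotangent_in,) = vjp_fn(u)   # one entry per positional input of f

expected = u @ jax.jacrev(f)(x)
print(cotangent_in, expected)
```

With J(x) = [[2, 1, 0], [0, 0, 1]], u^T J(x) sums the rows, giving [2, 1, 1]. This is the same pass `jax.grad` uses when f has a scalar output.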