JAX reference documentation
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.
For an introduction to JAX, start at the JAX GitHub page.
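As a rough illustration of what "composable transformations" means here, the sketch below chains `jax.grad`, `jax.vmap`, and `jax.jit` on one small function. The function `loss` and the arrays `w` and `xs` are hypothetical names chosen for this example and do not come from the JAX documentation.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar-valued function of parameters w and one input x.
    return jnp.sum(jnp.tanh(w * x) ** 2)

# Differentiate with respect to the first argument (w).
grad_loss = jax.grad(loss)

# Vectorize over a batch of inputs x while sharing w across the batch.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# Compile the composed function with XLA for whichever backend is available.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.ones((4, 3))
grads = fast_batched_grad(w, xs)  # one gradient of shape (3,) per row of xs
```

Each transformation takes a function and returns a function, so they can be applied in sequence as above; the same code runs on CPU, GPU, or TPU depending on the installed backend.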
Getting Started
Advanced JAX Tutorials
- Convolutions in JAX
- The Autodiff Cookbook
- Autobatching log-densities example
- Training a Simple Neural Network, with tensorflow/datasets Data Loading
- Custom derivative rules for JAX-transformable Python functions
- How JAX primitives work
- Writing custom Jaxpr interpreters in JAX
- Training a Simple Neural Network, with PyTorch Data Loading
- XLA in Python
- MAML Tutorial with JAX
- Generative Modeling by Estimating Gradients of Data Distribution in JAX
Notes
- Change Log
- JAX Frequently Asked Questions (FAQ)
- JAX Errors
- Understanding Jaxprs
- Asynchronous dispatch
- Concurrency
- GPU memory allocation
- Profiling JAX programs
- Device Memory Profiling
- Pytrees
- Rank promotion warning
- Type promotion semantics
- custom_vjp and nondiff_argnums update guide
- JAX Glossary of Terms
Developer documentation