Design Notes

  • Custom JVP/VJP rules for JAX-transformable functions
  • Jax and Jaxlib versioning
  • Omnistaging
  • JAX PRNG Design
  • Design of Type Promotion Semantics for JAX
  • Sequencing side-effects in JAX

By The JAX authors
© Copyright 2020, Google LLC. NumPy and SciPy documentation are copyright the respective authors.