Python Module Index

j
jax
    jax.core
    jax.dlpack
    jax.experimental
    jax.experimental.host_callback
    jax.experimental.loops
    jax.experimental.maps
    jax.experimental.optimizers
    jax.experimental.stax
    jax.image
    jax.lax
    jax.lax.linalg
    jax.nn
    jax.nn.initializers
    jax.numpy
    jax.numpy.fft
    jax.numpy.linalg
    jax.ops
    jax.profiler
    jax.random
    jax.scipy.linalg
    jax.scipy.ndimage
    jax.scipy.optimize
    jax.scipy.signal
    jax.scipy.sparse.linalg
    jax.scipy.special
    jax.scipy.stats.bernoulli
    jax.scipy.stats.beta
    jax.scipy.stats.betabinom
    jax.scipy.stats.cauchy
    jax.scipy.stats.chi2
    jax.scipy.stats.dirichlet
    jax.scipy.stats.expon
    jax.scipy.stats.gamma
    jax.scipy.stats.geom
    jax.scipy.stats.laplace
    jax.scipy.stats.logistic
    jax.scipy.stats.multivariate_normal
    jax.scipy.stats.norm
    jax.scipy.stats.pareto
    jax.scipy.stats.poisson
    jax.scipy.stats.t
    jax.scipy.stats.uniform
    jax.tree_util

© Copyright 2020, Google LLC. NumPy and SciPy documentation are copyright the respective authors. Revision 6f0f7174.