All modules for which code is available

  • jax._src.api
  • jax._src.dtypes
  • jax._src.errors
  • jax._src.image.scale
  • jax._src.lax.control_flow
  • jax._src.lax.fft
  • jax._src.lax.lax
  • jax._src.lax.linalg
  • jax._src.lax.other
  • jax._src.lax.parallel
  • jax._src.nn.functions
  • jax._src.nn.initializers
  • jax._src.numpy.fft
  • jax._src.numpy.lax_numpy
  • jax._src.numpy.linalg
  • jax._src.numpy.polynomial
  • jax._src.numpy.vectorize
  • jax._src.ops.scatter
  • jax._src.profiler
  • jax._src.random
  • jax._src.scipy.linalg
  • jax._src.scipy.ndimage
  • jax._src.scipy.optimize.minimize
  • jax._src.scipy.signal
  • jax._src.scipy.sparse.linalg
  • jax._src.scipy.special
  • jax._src.scipy.stats.bernoulli
  • jax._src.scipy.stats.beta
  • jax._src.scipy.stats.betabinom
  • jax._src.scipy.stats.cauchy
  • jax._src.scipy.stats.chi2
  • jax._src.scipy.stats.dirichlet
  • jax._src.scipy.stats.expon
  • jax._src.scipy.stats.gamma
  • jax._src.scipy.stats.geom
  • jax._src.scipy.stats.laplace
  • jax._src.scipy.stats.logistic
  • jax._src.scipy.stats.multivariate_normal
  • jax._src.scipy.stats.norm
  • jax._src.scipy.stats.pareto
  • jax._src.scipy.stats.poisson
  • jax._src.scipy.stats.t
  • jax._src.scipy.stats.uniform
  • jax._src.third_party.numpy.linalg
  • jax._src.tree_util
  • jax.core
  • jax.custom_derivatives
  • jax.experimental.host_callback
  • jax.experimental.loops
  • jax.experimental.maps
  • jax.experimental.optimizers
  • jax.experimental.stax
  • jax.experimental.x64_context
  • jax.lib.xla_bridge
  • jaxlib.xla_extension
    • jaxlib.xla_extension.profiler
  • numpy.core.getlimits

© Copyright 2020, Google LLC. NumPy and SciPy documentation are copyright the respective authors. Revision 6f0f7174.
