jax.experimental package
jax.experimental.loops module
jax.experimental.optimizers module
jax.experimental.optix module
jax.experimental.stax module