Building from source
First, obtain the JAX source code:
git clone https://github.com/google/jax
cd jax
Building JAX involves two steps:
1. Building or installing jaxlib, the C++ support library for jax.
2. Installing the jax Python package.
Building or installing jaxlib#
Installing jaxlib with pip#
If you’re only modifying Python portions of JAX, we recommend installing jaxlib from a prebuilt wheel using pip:
pip install jaxlib
See the JAX readme for full guidance on pip installation (e.g., for GPU and TPU support).
Building jaxlib from source#
To build jaxlib from source, you must also install some prerequisites:
- a C++ compiler (g++, clang, or MSVC)
- the Python packages numpy and wheel
On Ubuntu or Debian you can install the necessary prerequisites with:
sudo apt install g++ python3 python3-dev
If you are building on a Mac, make sure Xcode and the Xcode command line tools are installed.
See below for Windows build instructions.
You can install the necessary Python dependencies using
pip install numpy wheel
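Before running the build, a quick sanity check can confirm that the prerequisites listed above are visible to the Python interpreter you will build with. This is a minimal, standard-library-only sketch; the helper name check_prerequisites and the compiler executables it probes (g++, clang, and cl for MSVC) are illustrative assumptions, not part of JAX's build tooling.

```python
import importlib.util
import shutil
import sys


def check_prerequisites(compilers=("g++", "clang", "cl")):
    """Report whether the Python packages and a C++ compiler needed to build jaxlib are available.

    Hypothetical helper: the compiler names probed here are assumptions
    (g++/clang on Linux and macOS, cl.exe for MSVC on Windows).
    """
    return {
        "python": sys.version.split()[0],
        # numpy and wheel are the Python packages installed with pip above.
        "numpy": importlib.util.find_spec("numpy") is not None,
        "wheel": importlib.util.find_spec("wheel") is not None,
        # shutil.which returns the full path of the first compiler found on PATH, or None.
        "compiler": next((path for path in map(shutil.which, compilers) if path), None),
    }


if __name__ == "__main__":
    for name, value in check_prerequisites().items():
        print(f"{name}: {value}")
```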
To build jaxlib without CUDA GPU or TPU support (CPU only), you can run:
python build/build.py
pip install dist/*.whl  # installs jaxlib (includes XLA)
To build jaxlib with CUDA support, use python build/build.py --enable_cuda; to build with TPU support, use python build/build.py --enable_tpu.
See python build/build.py --help for configuration options, including ways to specify the paths to CUDA and cuDNN, which you must have installed. Here python should be the name of your Python 3 interpreter; on some systems, you may need to use python3 instead. By default, the wheel is written to the dist/ subdirectory of the current directory.
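Because the wheel lands in dist/, a small script can locate the most recently built jaxlib wheel and split its filename into the fields defined by the wheel naming convention (PEP 427): distribution, version, optional build tag, Python tag, ABI tag, and platform tag. This is a sketch that assumes the default dist/ output directory; the helper names are hypothetical.

```python
from pathlib import Path
from typing import NamedTuple, Optional


class WheelName(NamedTuple):
    distribution: str
    version: str
    build: Optional[str]
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelName:
    """Split a wheel filename per PEP 427: name-version(-build)?-python-abi-platform.whl."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, py, abi, plat = parts
        build = None
    elif len(parts) == 6:
        name, version, build, py, abi, plat = parts
    else:
        raise ValueError(f"unexpected wheel filename: {filename}")
    return WheelName(name, version, build, py, abi, plat)


def newest_wheel(dist_dir: str = "dist", prefix: str = "jaxlib-") -> Optional[Path]:
    """Return the most recently modified jaxlib wheel in dist_dir, or None if there is none."""
    wheels = list(Path(dist_dir).glob(f"{prefix}*.whl"))
    return max(wheels, key=lambda p: p.stat().st_mtime, default=None)
```

For example, newest_wheel() returns the path you would pass to pip install, and parse_wheel_filename(path.name).platform_tag tells you which platform the wheel was built for.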
Additional Notes for Building jaxlib from source on Windows#
On Windows, follow Install Visual Studio to set up a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required. If you need to build with CUDA enabled, follow the CUDA Installation Guide to set up a CUDA environment.
JAX builds use symbolic links, which require that you activate Developer Mode.
Install the following MSYS2 packages:
pacman -S patch coreutils
Once coreutils is installed, the realpath command should be present in your shell’s path.
Once everything is installed, open PowerShell and make sure MSYS2 is in the path of the current session. Ensure that the tools installed above (such as patch and realpath) are accessible, and activate the conda environment. The following command builds with CUDA enabled; adjust it as needed for your setup:
python .\build\build.py `
  --enable_cuda `
  --cuda_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
  --cudnn_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
  --cuda_version='10.1' `
  --cudnn_version='7.6.5'
To build with debug information, add the flag
Installing jax#
Once jaxlib has been installed, you can install jax by running:
pip install -e . # installs jax
To upgrade to the latest version from GitHub, just run git pull from the JAX repository root, and rebuild by running build.py or upgrading jaxlib if necessary. You shouldn’t have to reinstall jax, because pip install -e sets up symbolic links from site-packages into the repository.
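To confirm that the editable install took effect, you can check where Python resolves the jax package from: it should point inside your repository checkout rather than at a separate copy in site-packages. A standard-library sketch; the helper names are hypothetical.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def package_location(name: str) -> Optional[Path]:
    """Return the directory a module or package is imported from, or None if unavailable."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.has_location or spec.origin is None:
        return None  # not installed, frozen, or a namespace package
    # resolve() follows symbolic links, so a linked install points back into the repo.
    return Path(spec.origin).resolve().parent


def is_installed_from(name: str, repo_root: str) -> bool:
    """True if `name` is imported from somewhere under repo_root (e.g. an editable install)."""
    location = package_location(name)
    if location is None:
        return False
    root = Path(repo_root).resolve()
    return location == root or root in location.parents
```

For example, running is_installed_from("jax", ".") from the repository root should return True after pip install -e .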
Running the tests#
To run all the JAX tests, we recommend using pytest-xdist, which can run tests in parallel. First, install pytest-xdist and pytest-benchmark by running pip install -r build/test-requirements.txt.
Then, from the repository root directory run:
pytest -n auto tests
JAX generates test cases combinatorially, and you can control the number of cases that are generated and checked for each test (default is 10). The automated tests currently use 25:
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
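To make "generated combinatorially" concrete: conceptually, a test is parameterized over the Cartesian product of several axes (shapes, dtypes, and so on), and only a bounded, reproducible subset of those combinations is actually checked. The sketch below illustrates that idea with itertools.product and a seeded random.sample; it is not JAX's test utility code, and the function name and axis values are made up. It reads the cap from JAX_NUM_GENERATED_CASES, defaulting to 10 as described above.

```python
import itertools
import os
import random


def sample_cases(axes, num_cases=None, seed=0):
    """Pick a reproducible subset of all combinations of the given parameter axes.

    Conceptual sketch only; JAX's real test utilities differ in detail.
    """
    if num_cases is None:
        # Mirror the documented knob: 10 cases unless JAX_NUM_GENERATED_CASES is set.
        num_cases = int(os.environ.get("JAX_NUM_GENERATED_CASES", "10"))
    all_cases = list(itertools.product(*axes))
    rng = random.Random(seed)  # fixed seed so reruns check the same cases
    return rng.sample(all_cases, min(num_cases, len(all_cases)))


shapes = [(), (3,), (2, 3), (4, 5, 6)]
dtypes = ["float32", "int32", "bool"]
for shape, dtype in sample_cases([shapes, dtypes], num_cases=5):
    print(shape, dtype)
```

Raising the cap checks more of the product space at the cost of a longer test run, which is the trade-off the environment variable controls.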
The automated tests also run the tests with default 64-bit floats and ints:
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
You can run a more specific set of tests using pytest’s built-in selection mechanisms, or alternatively you can run a specific test file directly to see more detailed information about the cases being run:
python tests/lax_numpy_test.py --num_generated_cases=5
You can skip a few tests known to be slow by setting the environment variable JAX_SKIP_SLOW_TESTS=1.
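If you switch between these configurations often, a small Python wrapper can assemble the environment variables described above and invoke pytest through the current interpreter. This is a hypothetical convenience script, not part of the JAX repository; it uses only the variables and flags shown in this section (JAX_NUM_GENERATED_CASES, JAX_ENABLE_X64, JAX_SKIP_SLOW_TESTS, and pytest-xdist's -n auto) and assumes it is run from the repository root.

```python
import os
import subprocess
import sys


def pytest_command(num_cases=25, enable_x64=False, skip_slow=False,
                   parallel=True, target="tests"):
    """Build the (argv, env) pair for a pytest run configured like the automated tests."""
    env = dict(os.environ)
    env["JAX_NUM_GENERATED_CASES"] = str(num_cases)
    for var, enabled in (("JAX_ENABLE_X64", enable_x64), ("JAX_SKIP_SLOW_TESTS", skip_slow)):
        if enabled:
            env[var] = "1"
        else:
            env.pop(var, None)  # don't inherit a stale setting from the parent shell
    argv = [sys.executable, "-m", "pytest"]
    if parallel:
        argv += ["-n", "auto"]  # requires pytest-xdist
    argv.append(target)
    return argv, env


def run_tests(**kwargs):
    """Run pytest with the configuration built by pytest_command; return its exit code."""
    argv, env = pytest_command(**kwargs)
    return subprocess.run(argv, env=env).returncode
```

For example, run_tests(enable_x64=True) corresponds to the JAX_ENABLE_X64=1 command above, and parallel=False drops the -n auto option.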
To specify a particular set of tests to run from a test file, you can pass a string or regular expression via the --test_targets flag. For example, you can run all the tests whose names match testPad using:
python tests/lax_numpy_test.py --test_targets="testPad"
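As an illustration of how a plain string such as "testPad" can select a whole family of generated test cases while a more specific pattern narrows the selection, the sketch below filters a list of names with Python's re module. The test names are invented, and the use of re.search (match anywhere in the name) is an assumption made for this example, not a description of how JAX implements --test_targets.

```python
import re


def select_tests(names, pattern):
    """Return the test names matching `pattern` (assumes search-anywhere semantics)."""
    regex = re.compile(pattern)
    return [name for name in names if regex.search(name)]


names = [
    "testPad_shape=float32[3]_mode=constant",  # hypothetical generated names
    "testPadWithValue_shape=int32[2,2]",
    "testReshape_shape=float32[6]",
]
print(select_tests(names, "testPad"))    # both pad tests
print(select_tests(names, r"testPad_"))  # only the first
```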
The Colab notebooks are tested for errors as part of the documentation build.
Note that to run the full pmap tests on a (multi-core) CPU-only machine, you can run:
I.e. don’t use the
-n auto option, since that effectively runs each test on a
Doctests#
JAX uses pytest in doctest mode to test the code examples within the documentation. You can run this using
Additionally, JAX runs pytest in doctest-modules mode to ensure code examples in function docstrings will run correctly. You can run this locally using, for example:
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
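For context, doctest-modules mode collects the interactive examples (lines starting with >>>) from docstrings, runs them, and compares what they print with the expected output written below each example. A minimal standard-library illustration, using a made-up function rather than anything from jax/_src/numpy/lax_numpy.py:

```python
import doctest


def pad_list(values, width, fill=0):
    """Right-pad a list to `width` elements.

    >>> pad_list([1, 2], 4)
    [1, 2, 0, 0]
    >>> pad_list([1, 2, 3], 2)
    [1, 2, 3]
    """
    return list(values) + [fill] * max(0, width - len(values))


def check_docstring_examples(obj):
    """Run the examples in obj's docstring; return doctest's (failed, attempted) counts."""
    runner = doctest.DocTestRunner(verbose=False)
    for test in doctest.DocTestFinder().find(obj, globs={obj.__name__: obj}):
        runner.run(test)  # mismatches are reported on stdout
    return runner.summarize(verbose=False)


print(check_docstring_examples(pad_list))
```

If you changed the expected output [1, 2, 0, 0] to something else, the run would report a failure, which is the same kind of error pytest surfaces for a stale docstring example.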
Keep in mind that there are several files that are marked to be skipped when the doctest command is run on the full package; you can see the details in
Type checking#
We use mypy to check the type hints. To check types locally the same way as the CI checks them:
pip install mypy
mypy --config=mypy.ini --show-error-codes jax
Alternatively, you can use the pre-commit framework to run this on all staged files in your git repository, automatically using the same mypy version as in the GitHub CI:
pre-commit run mypy
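As a reminder of what this check catches: mypy reads the annotations without running the code and reports calls whose argument types don't match. The snippet below is a generic illustration, not JAX code; the commented-out call is the kind of line mypy would flag, reported with the [arg-type] error code when --show-error-codes is on.

```python
from typing import List, Sequence


def normalize_axis(axis: int, ndim: int) -> int:
    """Map a possibly negative axis index into the range [0, ndim)."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    return axis % ndim


def normalize_axes(axes: Sequence[int], ndim: int) -> List[int]:
    return [normalize_axis(a, ndim) for a in axes]


print(normalize_axes([0, -1], ndim=3))  # [0, 2]
# normalize_axis("0", 3)  # mypy: incompatible type "str"; expected "int"  [arg-type]
```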
Linting#
JAX uses the flake8 linter to ensure code quality. You can check your local changes by running:
pip install flake8
flake8 jax
Alternatively, you can use the pre-commit framework to run this on all staged files in your git repository, automatically using the same flake8 version as the GitHub tests:
pre-commit run flake8
Update documentation#
To rebuild the documentation, install several packages:
pip install -r docs/requirements.txt
And then run:
sphinx-build -b html docs docs/build/html -j auto
This can take a long time because it executes many of the notebooks in the documentation source; if you’d prefer to build the docs without executing the notebooks, you can run:
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
You can then see the generated documentation in
The -j auto option controls the parallelism of the build. You can use a number in place of auto to control how many CPU cores to use.
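The two sphinx-build invocations and the -j setting above can be wrapped in a small helper if you switch between them often. This is a hypothetical script that simply assembles the same command line and shells out to it; it assumes it is run from the repository root with the documentation requirements installed.

```python
import subprocess
from typing import List, Union


def sphinx_build_argv(execute_notebooks: bool = True, jobs: Union[int, str] = "auto",
                      source: str = "docs", out: str = "docs/build/html") -> List[str]:
    """Assemble the sphinx-build command used to render the HTML documentation."""
    argv = ["sphinx-build", "-b", "html"]
    if not execute_notebooks:
        # Skip executing notebooks, which is the slow part of the build.
        argv += ["-D", "nb_execution_mode=off"]
    argv += [source, out, "-j", str(jobs)]
    return argv


def build_docs(**kwargs) -> int:
    """Run sphinx-build with the assembled arguments; return its exit code."""
    return subprocess.run(sphinx_build_argv(**kwargs)).returncode
```

For example, build_docs(execute_notebooks=False, jobs=4) runs the faster, non-executing build on four cores.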
Update notebooks#
We use jupytext to maintain two synced copies of the notebooks in docs/notebooks: one in ipynb format, and one in md format. The advantage of the former is that it can be opened and executed directly in Colab; the advantage of the latter is that it makes it much easier to track diffs within version control.
For making large changes that substantially modify code and outputs, it is easiest to edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface, open http://colab.research.google.com and upload the notebook from your local repo. Update it as needed, run all cells, and then download the updated notebook back into your local repo. You may want to test that it executes properly, using sphinx-build as explained above.
For making smaller changes to the text content of the notebooks, it is easiest to edit the .md versions using a text editor.
After editing either the ipynb or md versions of the notebooks, you can sync the two versions using jupytext by running jupytext --sync on the updated notebooks; for example:
pip install jupytext==1.13.8
jupytext --sync docs/notebooks/quickstart.ipynb
The jupytext version should match that specified in .pre-commit-config.yaml.
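One way to check that requirement is to compare the installed jupytext version against the rev pinned for the jupytext hook in .pre-commit-config.yaml. The sketch below uses only the standard library; it assumes the hook entry follows the usual pre-commit layout (a repo: URL containing "jupytext" followed by a rev: line, optionally prefixed with "v" or quoted), and it parses the file with a regular expression rather than a YAML parser, so treat it as a rough check. The helper names are hypothetical.

```python
import re
from importlib import metadata
from pathlib import Path
from typing import Optional


def pinned_jupytext_version(config_text: str) -> Optional[str]:
    """Extract the rev pinned for the jupytext repo in a pre-commit config (regex-based)."""
    match = re.search(r"repo:\s*\S*jupytext\S*\s+rev:\s*[\"']?v?([\w.]+)", config_text)
    return match.group(1) if match else None


def installed_jupytext_version() -> Optional[str]:
    try:
        return metadata.version("jupytext")
    except metadata.PackageNotFoundError:
        return None


def check_jupytext(config_path: str = ".pre-commit-config.yaml") -> bool:
    """Print both versions and return True only if they match exactly."""
    pinned = pinned_jupytext_version(Path(config_path).read_text())
    installed = installed_jupytext_version()
    print(f"pinned: {pinned}, installed: {installed}")
    return pinned is not None and pinned == installed
```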
To check that the markdown and ipynb files are properly synced, you may use the pre-commit framework to perform the same check used by the GitHub CI:
git add docs -u  # pre-commit runs on files in git staging.
pre-commit run jupytext
Creating new notebooks#
If you are adding a new notebook to the documentation and would like to use the jupytext --sync command discussed here, you can set up your notebook for jupytext by using the following command:
jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
This works by adding a "jupytext" metadata field to the notebook file which specifies the desired formats, and which the jupytext --sync command recognizes when invoked.
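Because the pairing lives in the notebook's own JSON, you can inspect it directly, for example to find notebooks that were added without being paired. The sketch below reads .ipynb files with the json module and looks for the formats string under the notebook's metadata, in the "jupytext" field that the --set-formats command writes; the helper names are hypothetical, and docs/notebooks is assumed as the default location.

```python
import json
from pathlib import Path
from typing import List


def paired_formats(notebook: dict) -> List[str]:
    """Return the jupytext formats a notebook is paired with (empty if unpaired)."""
    formats = notebook.get("metadata", {}).get("jupytext", {}).get("formats", "")
    return [fmt for fmt in formats.split(",") if fmt]


def unpaired_notebooks(directory: str = "docs/notebooks") -> List[Path]:
    """List .ipynb files under `directory` that have no md pairing recorded."""
    result = []
    for path in sorted(Path(directory).rglob("*.ipynb")):
        formats = paired_formats(json.loads(path.read_text(encoding="utf-8")))
        if not any(fmt.startswith("md") for fmt in formats):
            result.append(path)
    return result
```

A notebook set up with the command above would report ["ipynb", "md:myst"] from paired_formats.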
Notebooks within the sphinx build#
Some of the notebooks are built automatically as part of the pre-submit checks and as part of the Read the Docs build.
The build will fail if cells raise errors. If the errors are intentional, you can either catch them or tag the cell with raises-exceptions metadata (example PR). You have to add this metadata by hand in the .ipynb file. It will be preserved when somebody else re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations. See exclude_patterns in conf.py.
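For orientation, exclude_patterns is a standard Sphinx configuration value: a list of glob-style patterns, relative to the documentation source directory, naming files Sphinx should skip. The fragment below only sketches the shape of such an entry; the notebook name is a hypothetical placeholder, not part of JAX's actual exclusion list.

```python
# docs/conf.py (fragment): patterns are relative to the docs source directory.
exclude_patterns = [
    "build/**",  # generated HTML output, which lives inside docs/ per the sphinx-build command
    # Hypothetical entry: a notebook too slow to execute during every docs build.
    "notebooks/long_running_example.ipynb",
]
```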
Documentation building on readthedocs.io#
JAX’s auto-generated documentation is at https://jax.readthedocs.io/.
The documentation building is controlled for the entire project by the readthedocs JAX settings. The current settings trigger a documentation build as soon as code is pushed to the GitHub
For each code version, the building process is driven by the .readthedocs.yml and the docs/conf.py configuration files.
For each automated documentation build you can see the documentation build logs.
If you want to test the documentation generation on Read the Docs, you can push code to the test-docs branch. That branch is also built automatically, and you can see the generated documentation here. If the documentation build fails, you may want to wipe the build environment for test-docs.
For a local test, I was able to do it in a fresh directory by replaying the commands I saw in the Readthedocs logs:
mkvirtualenv jax-docs  # A new virtualenv
mkdir jax-docs  # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/google/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
workon jax-docs
python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 'alabaster>=0.7,<0.8,!=0.7.5' commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html