# `jax.extend`: a module for extensions
```python
import jax.extend as jex
```
Several projects depend on JAX’s codebase internals, often to use its core machinery (e.g. to write a transformation over its IR) or to extend it (e.g. to define new primitives). Two challenges for these dependencies are (a) that our internals aren’t all solidly designed for external use, and (b) that circumventing JAX’s public API is unsupported. In other words, our internals are often used like a library, but are neither structured nor updated like one.
This proposal considers introducing a `jax.extend` module that defines a library view of some of JAX's internal components. We would treat this as a second-tier API, still offering essentially no compatibility guarantees, but hopefully making it easier to spot changes when they happen.
The audience for `jax.extend` includes JAX-adjacent Python libraries such as Oryx, jax-triton, and many others, as well as projects experimenting with function transformations, autodiff systems, compiler frontends for numerical programming, etc.
This note gives an overview of how `jax.extend` might look, now and eventually. It doesn't lay things out in great detail. Instead, it proposes that we begin iteratively developing the module.
Note that `jax.extend` differs from `jax.experimental`, which is a staging ground for new features and ideas in progress. Typically, work in `jax.experimental` eventually makes it into another JAX module or is removed altogether.
## No compatibility policy
To keep development overhead low, `jax.extend` would not follow JAX's public API compatibility policy. It would promise neither deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the changelog to call out such changes.
Callers of `jax.extend` that need to upgrade their code regularly alongside JAX releases might find it useful to pin JAX versions as an intermediate step between releases. This is a common habit among projects that rely on JAX's internals today. The difference is that pinning would now come with the help of changelog announcements and better intentions regarding library design and naming.
Having no compatibility policy makes it easier to get started on implementation: on day one, we can move a handful of symbols over from internal packages such as `jax._src` and today's `jax.interpreters`. Then we can iterate to improve things from there.
## Possible module overview
We can imagine that eventually `jax.extend` would include the following modules:

- `core` – primitives, the Jaxpr IR, etc.
- `interpreters` – core transformations (e.g. autodiff, batching) and lowerings.
- `random` – random bit generation, key splitting and folding, key arrays.
- `sharding` – extra functionality around distributed arrays.
We might also have other symbols in the module at first, such as `jex.api_util`, as we work to remove or replace them. Others will be decided in time. For instance, `jex.lib` could offer an entry point to jaxlib (and would do so in the immediate term), but it's not clear whether we want to keep it for long.
Some preliminary thoughts on what each of these might comprise follow.
### `jax.extend.core`

This module should enable callers at least to define new JAX primitives and to process the Jaxpr IR (the output of `jax.make_jaxpr(...)`). Supporting this might involve providing:

- Access to existing core system primitives.
- Access to IR types, such as the current `ShapedArray`.
- Functions for checking and pretty-printing jaxprs.
- Functions for building jaxprs explicitly, rather than by staging Python functions via `jax.make_jaxpr`.
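Much of the Jaxpr processing described above can already be sketched with public entry points. The snippet below is a minimal sketch, not a proposed `jex.core` API. It stages a function with `jax.make_jaxpr`, pretty-prints the resulting `ClosedJaxpr`, and tallies the primitives in its top-level equations. The helper `count_primitives` is hypothetical. It does not descend into sub-jaxprs, such as those carried by `jit` or `cond` equations.

```python
import collections

import jax
import jax.numpy as jnp


def count_primitives(closed_jaxpr):
    """Count primitive names in a jaxpr's top-level equations (hypothetical helper)."""
    # Sub-jaxprs held in equation params (e.g. for jit or cond) are not visited.
    return collections.Counter(eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns)


def f(x):
    return jax.lax.mul(jax.lax.sin(x), x)


closed = jax.make_jaxpr(f)(jnp.ones(3))
print(closed)                      # pretty-printed jaxpr text
counts = count_primitives(closed)
print(dict(counts))
```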
At initialization, this module will contain many more symbols than what's needed to define primitives and rules. These include various names used in setting up final-style transformations, such as the current `Tracer` classes. We can revisit whether `jex.core` should also support final-style extensions alongside initial-style approaches, and whether it can do so through a narrower API than exposing the `Tracer` classes wholesale. Oryx might help guide these decisions.
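As an illustration of the initial-style approach, the following sketch evaluates a staged jaxpr by re-binding each equation's primitive in order. It assumes `Literal` can be imported from `jax.extend.core`, falling back to `jax.core` on older releases. The function name `interpret_jaxpr` is hypothetical. A transformation would replace the plain `bind` call with its own per-primitive logic.

```python
import jax
import jax.numpy as jnp

try:  # recent releases export Literal from jax.extend.core
    from jax.extend.core import Literal
except ImportError:  # assumption: older releases expose it as jax.core.Literal
    from jax.core import Literal


def interpret_jaxpr(jaxpr, consts, *args):
    """Evaluate a jaxpr by binding each equation's primitive (hypothetical helper)."""
    env = {}

    def read(v):
        # Literals carry their value inline; variables are looked up in env.
        return v.val if isinstance(v, Literal) else env[v]

    def write(v, val):
        env[v] = val

    for v, val in zip(jaxpr.constvars, consts):
        write(v, val)
    for v, val in zip(jaxpr.invars, args):
        write(v, val)
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for v, val in zip(eqn.outvars, outvals):
            write(v, val)
    return [read(v) for v in jaxpr.outvars]


def f(x):
    return jnp.sin(x) * 2.0 + x


x = jnp.arange(3.0)
closed = jax.make_jaxpr(f)(x)
(out,) = interpret_jaxpr(closed.jaxpr, closed.consts, x)
```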
We can also consider relocating `make_jaxpr` itself to `jex.core`.
### `jax.extend.interpreters`

This module would provide a means of registering various transformation rules for primitives, defining their behavior under AD, batching, lowering, etc.

It would initially mirror `jax._src.interpreters`, providing modules for AD, batching, and `partial_eval` (for staging Python to Jaxpr, and for linearization in AD), along with several modules used for lowering, such as `mlir`. The transformation modules might be replaceable by a single primitive extension API in `jex.core`, and the lowering modules could perhaps be simplified into one.
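Today, such rules live in per-primitive tables keyed by `Primitive` objects. The sketch below assumes the current public `jax.interpreters.ad` and `jax.interpreters.batching` modules. It checks that two built-in `jax.lax` primitives have entries in the JVP and batching tables. A `jex.interpreters` module might expose registration differently.

```python
import jax
from jax.interpreters import ad, batching

# Rule tables are dicts keyed by Primitive objects; registering a rule
# for a new primitive amounts to adding an entry to the relevant table.
for prim in (jax.lax.sin_p, jax.lax.mul_p):
    print(
        prim.name,
        "jvp:", prim in ad.primitive_jvps,
        "batching:", prim in batching.primitive_batchers,
    )
```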
Today, to write transformation rules (e.g. for AD and batching), callers may need symbols relating to tracers, such as `BatchTracer`. This may be avoidable later on, which would allow us to remove tracer types from `jax.extend`.
This module plus `jex.core` ought to suffice for replicating today's custom primitive tutorials. For instance, defining a primitive and its behavior under `jax.jit` would be possible as follows (in the immediate term):
```python
from jax.extend import core                # Previously: from jax import core
from jax.extend.interpreters import mlir   # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)
```
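Building on the `mul_add` example above, here is a sketch of how autodiff and batching rules could be added. It relies on these assumptions:

- The public `jax.interpreters.ad.defjvp` and `jax.interpreters.batching.primitive_batchers` are available.
- `Primitive` is imported from `jax.extend.core`, with a fallback to `jax.core`.
- All three operands share one shape and dtype.

The rules never refer to tracer types directly. The MLIR lowering is omitted here, so the sketch is exercised without `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad, batching

try:  # recent releases export Primitive from jax.extend.core
    from jax.extend.core import Primitive
except ImportError:  # older releases: jax.core.Primitive
    from jax.core import Primitive

mul_add_p = Primitive("mul_add")
mul_add_p.def_impl(lambda x, y, z: x * y + z)


@mul_add_p.def_abstract_eval
def _mul_add_abstract(x, y, z):
    # Assumes x, y, z share one shape and dtype.
    return jax.core.ShapedArray(x.shape, x.dtype)


def mul_add(x, y, z):
    return mul_add_p.bind(x, y, z)


# JVP: one rule per operand. ad.defjvp skips operands whose tangent is a
# symbolic zero and sums the contributions of the rest.
ad.defjvp(
    mul_add_p,
    lambda xdot, x, y, z: xdot * y,
    lambda ydot, x, y, z: x * ydot,
    lambda zdot, x, y, z: zdot,
)


def _mul_add_batch(args, dims):
    # Move each batch axis to the front and broadcast unbatched operands
    # (dim None), then apply the primitive to whole batches at once.
    size = next(a.shape[d] for a, d in zip(args, dims) if d is not None)
    moved = [
        jnp.moveaxis(a, d, 0) if d is not None
        else jnp.broadcast_to(a, (size,) + jnp.shape(a))
        for a, d in zip(args, dims)
    ]
    return mul_add(*moved), 0  # output is batched along axis 0


batching.primitive_batchers[mul_add_p] = _mul_add_batch

x, y, z = jnp.asarray(2.0), jnp.asarray(3.0), jnp.asarray(4.0)
primal, tangent = jax.jvp(
    mul_add, (x, y, z), (jnp.ones_like(x), jnp.zeros_like(y), jnp.zeros_like(z))
)
batched = jax.vmap(mul_add, in_axes=(0, None, None))(jnp.arange(3.0), y, 1.0)
```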
### `jax.extend.random`

This module could expose our mechanism for defining new RNG implementations, and functions for working with PRNG key internals (see issue #9263).
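For comparison, JAX's public API already offers a coarse form of key wrapping and unwrapping. The sketch below uses `jax.random.key_data` and `jax.random.wrap_key_data` to round-trip a typed key through its raw `uint32` data. Whether `jex.random` would expose these or lower-level counterparts remains an open question.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)                    # typed key array (default impl)
raw = jax.random.key_data(key)             # unwrap: raw uint32 key data
same_key = jax.random.wrap_key_data(raw)   # wrap it back, using the default impl

# Both keys hold the same data, so they produce the same samples.
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(same_key, (3,))
```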
It could also expose the keyed hash functions that underlie the built-in RNG implementations.
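The determinism these keyed hash functions provide can be observed through the public API. This sketch does not call the hash directly. (JAX's default implementation is based on Threefry.) Drawing bits twice from the same key gives identical results, while `jax.random.fold_in` derives a key that yields a different stream.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)

# Same key and same request: the underlying keyed hash yields identical bits.
bits_a = jax.random.bits(key, (4,), jnp.uint32)
bits_b = jax.random.bits(key, (4,), jnp.uint32)

# fold_in derives a new key from `key` and the integer 1; its bits differ.
bits_c = jax.random.bits(jax.random.fold_in(key, 1), (4,), jnp.uint32)
```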
### `jax.extend.sharding`

This module could expose low-level utilities for sharding distributed arrays.

We have only one item in mind for now. The XLA compiler's array sharding format is more expressive than the sharding formats provided by JAX. We could provide it as `jex.sharding.XlaOpShardingProto`, corresponding to the XLA sharding format that JAX uses internally today.
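For contrast, the sketch below expresses a sharding with JAX's public `NamedSharding` and `PartitionSpec`. It splits one array axis across a one-dimensional mesh of all available devices. It does not use the proposed `XlaOpShardingProto`, which would expose XLA's more general format, and it assumes a single-process setup.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))  # split axis 0 across mesh axis "x"

n = len(devices)
arr = jax.device_put(jnp.arange(4 * n), sharding)
# Each device holds one contiguous block of 4 elements.
shard_shapes = {s.data.shape for s in arr.addressable_shards}
```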