jax.extend: a module for extensions#

@froystig, @sharadmv, @jakevdp, @yashk2810

May 2023

import jax.extend as jex

Several projects depend on JAX’s codebase internals, often to use its core machinery (e.g. to write a transformation over its IR) or to extend it (e.g. to define new primitives). Two challenges for these dependencies are (a) that our internals aren’t all solidly designed for external use, and (b) that circumventing JAX’s public API is unsupported. In other words, our internals are often used like a library, but are neither structured nor updated like one.

This proposal considers introducing a jax.extend module that defines a library view of some of JAX’s internal components. We would treat this as a second-tier API, still offering essentially no compatibility guarantees, but hopefully making it easier to spot changes when they happen.

The audience for jax.extend includes JAX-adjacent Python libraries like Oryx, jax-triton, and many others, as well as projects experimenting with function transformations, autodiff systems, compiler frontends for numerical programming, etc.

This note gives an overview of how jax.extend might look, now and eventually. It doesn’t lay things out in great detail, instead proposing that we begin iteratively developing the module.

Note that jax.extend differs from jax.experimental, which is a staging ground for new features and ideas in progress. Typically, work in jax.experimental eventually makes it into another JAX module or is removed altogether.

No compatibility policy#

To keep development overhead low, jax.extend would not follow the public API compatibility policy. It would promise neither deprecation windows nor backwards compatibility between releases. Any release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the changelog to call out such changes.

Callers of jax.extend that need to upgrade their code regularly alongside JAX releases might find it useful to pin JAX versions as an intermediate step between releases. This is a common habit among projects that rely on JAX’s internals today. The difference is that it would now come with the help of changelog announcements and better intentions regarding library design and naming.
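Pinning is normally expressed in packaging metadata, but a downstream library can complement it with a runtime check. The following is a minimal sketch under stated assumptions: the version bounds and both helper names are hypothetical, chosen only for illustration.

```python
import warnings

# Hypothetical range of JAX releases that a downstream project was tested against.
TESTED_MIN = (0, 4, 14)            # inclusive lower bound (illustrative)
TESTED_MAX_EXCLUSIVE = (0, 5, 0)   # exclusive upper bound (illustrative)


def is_tested_version(version_info):
    """Return True if a (major, minor, patch) tuple lies in the tested range."""
    return TESTED_MIN <= tuple(version_info) < TESTED_MAX_EXCLUSIVE


def warn_if_untested(version_info):
    """Warn instead of failing, so that users can still try newer releases."""
    if not is_tested_version(version_info):
        warnings.warn(
            f"JAX {'.'.join(map(str, version_info))} is outside the tested range; "
            "jax.extend symbols may have changed. Check the JAX changelog."
        )
```

A library would call `warn_if_untested(jax.__version_info__)` once at import time. A warning rather than an error is a design choice here: it keeps the pin advisory while still pointing users to the changelog.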

Iterative development#

Having no compatibility policy makes it easier to get started on implementation: on day one, we can move a handful of symbols over from internal packages such as jax._src and today’s jax.core and jax.interpreters. Then we can iterate to improve things from there.

Possible module overview#

We can imagine that eventually jax.extend would include the following modules:

  • core – primitives, the Jaxpr IR, etc.

  • interpreters – core transformations (e.g. autodiff, batching) and lowerings.

  • random – random bit generation, key splitting and folding, key arrays.

  • sharding – extra functionality around distributed arrays.

We might also have other symbols in the module at first, such as jex.api_util, as we work to remove or replace them. Others will be decided in time. For instance, jex.lib could offer an entry point to jaxlib (and would do so in the immediate term), but it’s not clear whether we want to keep it for long.

Some preliminary thoughts on what each of these might comprise follow.

jax.extend.core#

This should enable callers at least to define new JAX primitives and to process the Jaxpr IR (the output of jax.make_jaxpr(...)). Supporting this might involve providing:

  • Access to existing core system primitives, such as today’s jax._src.lax.add_p.

  • Access to IR types, such as the current jax._src.core.ShapedArray.

  • Functions for checking and pretty-printing jaxprs.

  • Functions for building jaxprs explicitly, rather than by staging Python functions via jax.make_jaxpr (or not!).
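As an illustration of what "processing the Jaxpr IR" means in practice, here is a minimal sketch that stages a function with jax.make_jaxpr and walks the resulting equations. It uses only attributes available on today's jaxpr objects (`.jaxpr`, `.eqns`, `.primitive.name`), and the helper `count_primitives` is a hypothetical name, not part of any proposed API.

```python
import collections

import jax
from jax import lax


def f(x):
    # Written directly against lax so that each call maps to one primitive.
    return lax.add(lax.mul(lax.sin(x), x), x)


def count_primitives(jaxpr):
    """Count top-level equations by primitive name (nested jaxprs are not visited)."""
    return collections.Counter(eqn.primitive.name for eqn in jaxpr.eqns)


closed_jaxpr = jax.make_jaxpr(f)(1.0)   # a ClosedJaxpr wrapping the IR
counts = count_primitives(closed_jaxpr.jaxpr)

print(closed_jaxpr)    # pretty-printed IR
print(dict(counts))    # one equation each for sin, mul, and add
```

A real IR transformation would do more than count, for example by rewriting or re-evaluating each equation, but it would start from the same loop over `eqns`.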

At first, this module will contain many more symbols than what’s needed to define primitives and rules, including various names used in setting up “final-style transformations”, such as the current jax._src.core.Trace and Tracer classes. We can revisit whether jex.core should also support final-style extensions alongside initial-style approaches, and whether it can do so by a narrower API than exposing Trace and Tracer entirely. Oryx might help guide these decisions.

We can also consider relocating make_jaxpr itself to jex.core.

jax.extend.interpreters#

This module would provide a means of registering various transformation rules for primitives—defining their behavior under AD, batching, lowering, etc.

It would initially reflect jax._src.interpreters in providing the modules ad, batching, partial_eval (for staging Python to Jaxpr, and for linearization in AD), mlir, pxla, and xla. The first three might be replaceable by a single primitive extension API in jex.core. The latter three, used for lowering, could perhaps be consolidated into one module.

Today, to write transformation rules, e.g. for AD and batching, callers may need symbols relating to tracers, e.g. JVPTracer and BatchTracer. This may be avoidable later on, which would allow us to remove tracer types from jex.

This module plus jex.core ought to suffice for replicating today’s custom primitive tutorials (e.g. ours and dfm’s). For instance, defining a primitive and its behavior under jax.jit would be possible as follows (in the immediate term):

from jax.extend import core                     # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

# Define the primitive and its eager (op-by-op) implementation.
mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

# Abstract evaluation: the output takes the shape and dtype of the first operand.
@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

# Lowering rule: emit an HLO multiply followed by an HLO add.
def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)
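The example above covers jax.jit only. As a sketch of registering a transformation rule for AD, the following defines the same primitive and gives it a forward-mode (JVP) rule. It is written against modules as they exist in JAX today (jax.interpreters.ad and its defjvp helper), not the proposed jax.extend.interpreters path, and the import fallback for Primitive is an assumption about how its location varies across releases.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad   # today's location, not the proposed jex path

try:
    from jax.extend.core import Primitive   # location in newer JAX releases
except ImportError:
    from jax.core import Primitive          # location in older JAX releases

mul_add_p = Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

# One JVP rule per operand. Each rule maps that operand's tangent g to its
# contribution to the output tangent of x * y + z.
ad.defjvp(
    mul_add_p,
    lambda g, x, y, z: g * y,   # d/dx
    lambda g, x, y, z: x * g,   # d/dy
    lambda g, x, y, z: g,       # d/dz
)


def mul_add(x, y, z):
    return mul_add_p.bind(x, y, z)


primals = (jnp.float32(2.0), jnp.float32(3.0), jnp.float32(4.0))
tangents = (jnp.float32(1.0), jnp.float32(0.0), jnp.float32(0.0))
primal_out, tangent_out = jax.jvp(mul_add, primals, tangents)
# primal_out is 2 * 3 + 4 = 10; tangent_out is the derivative in x, i.e. y = 3.
```

No abstract evaluation or lowering rule is registered here, so this sketch works only outside jax.jit. Combining it with the rules from the example above would be needed for staged execution.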

jax.extend.random#

This module could expose our mechanism for defining new RNG implementations, and functions for working with PRNG key internals (see issue #9263), such as the current jax._src.prng.random_wrap and random_unwrap.

It could also expose the keyed hash functions that underlie the built-in RNG implementations, such as jax._src.prng.threefry_2x32.
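To make "working with PRNG key internals" concrete, here is a minimal sketch of an unwrap and re-wrap round trip. It uses the public jax.random.key_data and jax.random.wrap_key_data functions from current JAX releases. Treating them as analogues of the internal random_unwrap and random_wrap is an assumption made for illustration, not something this note specifies.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)                      # typed PRNG key array
raw = jax.random.key_data(key)               # underlying unsigned-integer words
rewrapped = jax.random.wrap_key_data(raw)    # back to a typed key (default impl)

# The round trip preserves the key's bits, so both keys yield the same sample.
same_bits = bool(jnp.array_equal(jax.random.key_data(rewrapped), raw))
same_sample = bool(jax.random.uniform(key) == jax.random.uniform(rewrapped))
```

Code that needs to serialize keys or pass them through systems that only understand plain integer arrays would rely on this kind of round trip.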

jax.extend.sharding#

This module could expose low-level utilities for sharding distributed arrays.

We have only one item in mind for now. The XLA compiler’s array sharding format is more expressive than the sharding types that JAX provides. We could expose it as jex.sharding.XlaOpShardingProto, corresponding to today’s jax._src.lib.xla_client.OpSharding internally.