JEP 18137: Scope of JAX NumPy & SciPy Wrappers#

Jake VanderPlas

October 2023

Until now, the intended scope of jax.numpy and jax.scipy has been relatively ill-defined. This document proposes a well-defined scope for these packages to better guide and evaluate future contributions, and to motivate the removal of some out-of-scope code.

Background#

From the beginning, JAX has aimed to provide a NumPy-like API for executing code in XLA, and a big part of the project’s development has been building out the jax.numpy and jax.scipy namespaces as JAX-based implementations of NumPy and SciPy APIs. There has always been an implicit understanding that some parts of numpy and scipy are out-of-scope for JAX, but this scope has not been well defined. This can lead to confusion and frustration for contributors, because there’s no clear answer to whether potential jax.numpy and jax.scipy contributions will be accepted into JAX.

Why Limit the Scope?#

To avoid leaving this unsaid, we should be explicit: it is a fact that any code included in a project like JAX incurs a small but nonzero ongoing maintenance burden for the developers. The success of a project over time directly relates to the ability of maintainers to continue this maintenance for the sum of all the project’s parts: documenting functionality, responding to questions, fixing bugs, etc. For long-term success and sustainability of any software tool, it’s vital that maintainers carefully weigh whether any particular contribution will be a net positive for the project given its goals and resources.

Evaluation Rubric#

This document proposes a rubric of six axes along which the scope of any particular numpy or scipy API can be judged for inclusion into JAX. An API which is strong along all axes is an excellent candidate for inclusion in the JAX package; a significant weakness along any of the six axes is a good argument against inclusion in JAX.

Axis 1: XLA alignment#

The first axis we consider is the degree to which the proposed API aligns with native XLA operations. For example, jax.numpy.exp() is a function that more-or-less directly mirrors jax.lax.exp. A large number of functions in numpy, scipy.special, numpy.linalg, scipy.linalg, and others meet this criterion: such functions pass the XLA-alignment check when considering their inclusion into JAX.
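As a rough illustration of this kind of alignment, the sketch below compares the wrapper with the underlying lax function and prints the traced computation. The input array is an arbitrary choice made for this example.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.linspace(0.0, 1.0, 5)

# The NumPy-style wrapper and the lax function agree on floating-point input.
wrapper_out = jnp.exp(x)
primitive_out = lax.exp(x)

# Tracing the wrapper shows that it reduces to an exp operation.
jaxpr = jax.make_jaxpr(jnp.exp)(x)
print(jaxpr)
```

The wrapper mostly adds NumPy-style argument handling on top of the primitive, which is why functions like this are inexpensive to maintain.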

On the other end, there are functions like numpy.unique(), which do not directly correspond to any XLA operation, and in some cases are fundamentally incompatible with JAX’s current computational model, which requires statically-shaped arrays (e.g. unique returns a value-dependent dynamic array shape). Such functions do not pass the XLA alignment check when considering their inclusion into JAX.
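The static-shape constraint can be seen in a small sketch. It assumes the size and fill_value arguments of jax.numpy.unique(), which fix the output shape ahead of time; the function names unique_dynamic and unique_static are made up for this example.

```python
import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError

x = jnp.array([3, 1, 3, 2, 1])

@jax.jit
def unique_dynamic(values):
    # The output length depends on the data, so this cannot be traced.
    return jnp.unique(values)

@jax.jit
def unique_static(values):
    # A static size fixes the output shape; short results are padded.
    return jnp.unique(values, size=4, fill_value=-1)

try:
    unique_dynamic(x)
    dynamic_failed = False
except ConcretizationTypeError:
    dynamic_failed = True

padded = unique_static(x)
print(dynamic_failed, padded)
```

The padded form works under jit, but it is a departure from the NumPy API, which is part of what makes functions like this a weaker fit along this axis.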

We also consider as part of this axis the need for pure function semantics. For example, numpy.random is built on an implicitly-updated state-based RNG, which is fundamentally incompatible with JAX’s computational model built on XLA.

Axis 2: Array API Alignment#

The second axis we consider focuses on the Python Array API Standard: this is in some senses a community-driven outline of which array operations are central to array-oriented programming across a wide range of user communities. If an API in numpy or scipy is listed within the Array API standard, it is a strong signal that JAX should include it. Using the example from above, the Array API standard includes several variants of numpy.unique() (unique_all, unique_counts, unique_inverse, unique_values) which suggests that, despite the function not being precisely aligned with XLA, it is important enough to the Python user community that JAX should perhaps implement it.

Axis 3: Existence of Downstream Implementations#

For functionality that does not align with Axis 1 or 2, an important consideration for inclusion into JAX is whether there exist well-supported downstream packages that supply the functionality in question. A good example of this is scipy.optimize: while JAX does include a minimal set of wrappers of scipy.optimize functionality, a much more complete treatment exists in the JAXopt package, which is actively maintained by JAX collaborators. In cases like this, we should lean toward pointing users and contributors to these specialized packages rather than re-implementing such APIs in JAX itself.

Axis 4: Complexity & Robustness of Implementation#

For functionality that does not align with XLA, one consideration is the complexity of the proposed implementation. This overlaps to some degree with Axis 1, but is nevertheless important to call out. A number of functions contributed to JAX have relatively complex implementations that are difficult to validate and introduce outsized maintenance burdens. An example is jax.scipy.special.bessel_jn(): as of the writing of this JEP, its implementation is a non-straightforward iterative approximation that has convergence issues in some domains, and proposed fixes introduce further complexity. Had we more carefully weighed the complexity and robustness of the implementation when the contribution was proposed, we might have chosen not to accept it.

Axis 5: Functional vs. Object-Oriented APIs#

JAX works best with functional APIs rather than object-oriented APIs. Object-oriented APIs can hide impure semantics, which often makes them difficult to implement well. NumPy and SciPy generally stick to functional APIs, but sometimes provide object-oriented convenience wrappers.

An example of this is numpy.polynomial.Polynomial, which wraps lower-level operations like numpy.polyadd(), numpy.polydiv(), etc. In general, when there are both functional and object-oriented APIs available, JAX should avoid providing wrappers for the object-oriented APIs and instead provide wrappers for the functional APIs.

In cases where only the object-oriented APIs exist, JAX should avoid providing wrappers unless the case is strong along other axes.
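As a sketch of the functional style preferred here, polynomial arithmetic can be done directly on coefficient arrays with jax.numpy.polyadd() and jax.numpy.polyval(). The specific polynomials are arbitrary choices for illustration.

```python
import jax.numpy as jnp

# Coefficients are ordered from highest degree to lowest, as in numpy.polyadd.
p = jnp.array([1.0, 2.0])        # x + 2
q = jnp.array([1.0, 0.0, 3.0])   # x**2 + 3

total = jnp.polyadd(p, q)        # coefficients of x**2 + x + 5
value = jnp.polyval(total, 2.0)  # evaluate the sum at x = 2
print(total, value)
```

Because the inputs and outputs are plain arrays, these functions compose with jit, grad, and vmap without any wrapper class.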

Axis 6: General “Importance” to JAX Users & Stakeholders#

The decision to include a NumPy/SciPy API in JAX should also take into account the importance of the algorithm to the general user community. It is admittedly difficult to quantify who is a “stakeholder” and how this importance should be measured; but we include this to make clear that any decision about what to include in JAX’s NumPy and SciPy wrappers will involve some amount of discretion that cannot be easily quantified.

For existing APIs, searches for usage on GitHub may be useful in establishing importance or lack thereof. As an example, we might return to jax.scipy.special.bessel_jn() discussed above: a search shows that this function has only a handful of uses on GitHub, probably in part because of the previously mentioned accuracy issues.

Evaluation: what’s in scope?#

In this section, we’ll attempt to evaluate the NumPy and SciPy APIs, including some examples from the current JAX API, in light of the above rubric. This will not be a comprehensive listing of all existing functions and classes, but rather a more general discussion by submodule and topic, with relevant examples.

NumPy APIs#

numpy namespace#

We consider the functions in the main numpy namespace to be essentially all in-scope for JAX, due to its general alignment with XLA (Axis 1) and the Python Array API (Axis 2), as well as its general importance to the JAX user community (Axis 6). Some functions are perhaps borderline (functions like numpy.intersect1d(), numpy.setdiff1d(), and numpy.union1d() arguably fail parts of the rubric), but for simplicity we declare that all array functions in the main numpy namespace are in-scope for JAX.

numpy.linalg & numpy.fft#

The numpy.linalg and numpy.fft submodules contain many functions that broadly align with functionality provided by XLA. Others have complicated device-specific lowerings, but represent a case where importance to stakeholders (Axis 6) outweighs complexity. For this reason, we deem both of these submodules in-scope for JAX.

numpy.random#

numpy.random is out-of-scope for JAX, because state-based RNGs are fundamentally incompatible with JAX’s computation model. We instead focus on jax.random, which offers similar functionality using a counter-based PRNG.
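A minimal sketch of the explicit-key style used by jax.random follows. The seed value and array shapes are arbitrary choices for this example.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# The same key always yields the same values: no hidden state is updated.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# To draw fresh values, derive new keys explicitly.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
print(a, b, c)
```

Since the random state is passed in and out as an ordinary value, functions that draw random numbers remain pure and can be traced and compiled like any other JAX function.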

numpy.ma & numpy.polynomial#

The numpy.ma and numpy.polynomial submodules are mostly concerned with providing object-oriented interfaces to computations that can be expressed via other functional means (Axis 5); for this reason, we deem them out-of-scope for JAX.

numpy.testing#

NumPy’s testing functionality only really makes sense for host-side computation, and so we don’t include any wrappers for it in JAX. That said, JAX arrays are compatible with numpy.testing, and JAX makes frequent use of it throughout the JAX test suite.

SciPy APIs#

SciPy has no functions in the top-level namespace, but includes a number of submodules. We consider each below, leaving out modules which have been deprecated.

scipy.cluster#

The scipy.cluster module includes tools for hierarchical clustering, k-means, and related algorithms. These are weak along several axes, and would be better served by a downstream package. One function already exists within JAX (jax.scipy.cluster.vq.vq()) but has no obvious usage on GitHub: this suggests that clustering is not broadly important to JAX users.

Recommendation: deprecate and remove jax.scipy.cluster.vq.vq().

scipy.constants#

The scipy.constants module includes mathematical and physical constants. These constants can be used directly with JAX, and so there is no reason to re-implement this in JAX.

scipy.datasets#

The scipy.datasets module includes tools to fetch and load datasets. These fetched datasets can be used directly with JAX, and so there is no reason to re-implement this in JAX.

scipy.fft#

The scipy.fft module contains functions that broadly align with functionality provided by XLA, and fare well along other axes as well. For this reason, we deem them in-scope for JAX.

scipy.integrate#

The scipy.integrate module contains functions for numerical integration. The more sophisticated of these (quad, dblquad, ode) are out-of-scope for JAX by axes 1 & 4, since they tend to be loopy algorithms based on dynamic numbers of evaluations. jax.experimental.ode.odeint() is related, but rather limited and not under any active development.

JAX does currently include jax.scipy.integrate.trapezoid(), but this is only because numpy.trapz() was recently deprecated in favor of scipy.integrate.trapezoid(). For any particular input, its implementation could be replaced with one line of jax.numpy expressions, so it’s not a particularly useful API to provide.
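To illustrate the one-line replacement for the common case of 1D sample arrays, the sketch below compares the wrapper against an equivalent jax.numpy expression. The integrand and grid are arbitrary choices for this example.

```python
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

x = jnp.linspace(0.0, 1.0, 11)
y = x ** 2

wrapped = trapezoid(y, x)

# Equivalent expression for 1D inputs: average adjacent samples and
# weight each average by the spacing between sample points.
manual = jnp.sum((y[1:] + y[:-1]) * jnp.diff(x)) / 2
print(wrapped, manual)
```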

Based on Axes 1, 2, 4, and 6, scipy.integrate should be considered out-of-scope for JAX.

Recommendation: remove jax.scipy.integrate.trapezoid(), which was added in JAX 0.4.14.

scipy.interpolate#

The scipy.interpolate module provides both low-level and object-oriented routines for interpolating in one or more dimensions. These APIs rate poorly along a number of the axes above: they are class-based rather than low-level, and none but the simplest methods can be expressed efficiently in terms of XLA operations.

JAX does currently have wrappers for scipy.interpolate.RegularGridInterpolator. Were we considering this contribution today, we would probably reject it by the above criteria. But this code has been fairly stable so there is not much downside to continuing to maintain it.

Going forward, we should consider other members of scipy.interpolate to be out-of-scope for JAX.

scipy.io#

The scipy.io submodule has to do with file input/output. There is no reason to re-implement this in JAX.

scipy.linalg#

The scipy.linalg submodule contains functions that broadly align with functionality provided by XLA, and fast linear algebra is broadly important to the JAX user community. For this reason, we deem it in-scope for JAX.

scipy.ndimage#

The scipy.ndimage submodule contains a set of tools for working on image data. Many of these overlap with tools in scipy.signal (e.g. convolutions and filtering). JAX currently provides one scipy.ndimage API, in jax.scipy.ndimage.map_coordinates(). Additionally, JAX provides some image-related tools in the jax.image module. The DeepMind ecosystem includes dm-pix, a more full-featured set of tools for image manipulation in JAX. Given all these factors, we suggest that scipy.ndimage should be considered out-of-scope for JAX core; we can point interested users and contributors to dm-pix. We can consider moving map_coordinates to dm-pix or to another appropriate package.
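For reference, here is a small sketch of what the one existing wrapper does, assuming linear interpolation (order=1) on a toy 2x2 image invented for this example.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.array([[0.0, 1.0],
                   [2.0, 3.0]])

# One sample point at (row, col) = (0.5, 0.5); coordinates are given per axis.
coords = [jnp.array([0.5]), jnp.array([0.5])]

# order=1 selects linear interpolation between neighboring pixels.
sampled = map_coordinates(image, coords, order=1)
print(sampled)
```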

scipy.odr#

The scipy.odr module provides an object-oriented wrapper around ODRPACK for performing orthogonal distance regressions. It is not clear that this could be cleanly expressed using existing JAX primitives, and so we deem it out of scope for JAX itself.

scipy.optimize#

The scipy.optimize module provides high-level and low-level interfaces for optimization. Such functionality is important to a lot of JAX users, and very early on JAX created jax.scipy.optimize wrappers. However, developers of these routines soon realized that the scipy.optimize API was too constraining, and different teams began working on the JAXopt package and the Optimistix package, each of which contain a much more comprehensive and better-tested set of optimization routines in JAX.

Because of these well-supported external packages, we now consider scipy.optimize to be out-of-scope for JAX.

Recommendation: deprecate jax.scipy.optimize and/or make it a lightweight wrapper around an optional JAXopt or Optimistix dependency.

🟡 scipy.signal#

The scipy.signal module is mixed: some functions are squarely in-scope for JAX (e.g. correlate and convolve, which are more user-friendly wrappers of lax.conv_general_dilated), while many others are squarely out-of-scope (domain-specific tools with no viable lowering path to XLA). Potential contributions to jax.scipy.signal will have to be weighed on a case-by-case basis.
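A short sketch of the in-scope end of this module, using jax.scipy.signal.convolve() on two small 1D arrays chosen arbitrarily for illustration:

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve

signal = jnp.array([1.0, 2.0, 3.0])
kernel = jnp.array([0.0, 1.0, 0.5])

# "full" returns every point of overlap between the two inputs.
full = convolve(signal, kernel, mode="full")

# "same" returns an output with the same shape as the first input.
same = convolve(signal, kernel, mode="same")
print(full, same)
```

The mode argument is the kind of convenience these wrappers add over the more general, but more verbose, lax convolution interface.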

🟡 scipy.sparse#

The scipy.sparse submodule mainly contains data structures for storing and operating on sparse matrices and arrays in a variety of formats. Additionally, scipy.sparse.linalg contains a number of matrix-free solvers, suitable for use with sparse matrices, dense matrices, and linear operators.

The scipy.sparse array and matrix data structures are out-of-scope for JAX, because they do not align with JAX’s computational model (e.g. many operations depend on dynamically-sized buffers). JAX has developed the jax.experimental.sparse module as an alternative set of data structures that are more in-line with JAX’s computational constraints. For these reasons, we consider the data structures in scipy.sparse to be out-of-scope for JAX.

On the other hand, scipy.sparse.linalg has proven to be an interesting area, and jax.scipy.sparse.linalg includes the bicgstab, cg, and gmres solvers. These are useful to the JAX user community (Axis 6) but aside from this do not fare well along other axes. They would be very suitable for moving into a downstream library; one potential option may be Lineax, which features a number of linear solvers built on JAX.
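As a sketch of the matrix-free style these solvers support, the example below passes a function to jax.scipy.sparse.linalg.cg() in place of an explicit matrix. The small dense system and the helper name matvec are assumptions made for illustration.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])  # symmetric positive definite
b = jnp.array([1.0, 2.0])

def matvec(v):
    # The solver only needs the action of A on a vector, not A itself.
    return A @ v

# cg returns the solution together with an info placeholder.
x, _ = cg(matvec, b)
residual = jnp.linalg.norm(matvec(x) - b)
print(x, residual)
```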

Recommendation: explore moving sparse solvers into Lineax, and otherwise treat scipy.sparse as out-of-scope for JAX.

scipy.spatial#

The scipy.spatial module contains mainly object-oriented interfaces to spatial/distance computations and nearest neighbor searches. It is mostly out-of-scope for JAX.

The scipy.spatial.transform submodule provides tools for manipulating three-dimensional spatial rotations. It is a relatively complicated object-oriented interface, and could perhaps be better served by a downstream project. JAX currently contains partial implementations of Rotation and Slerp within jax.scipy.spatial.transform; these are object-oriented wrappers of otherwise basic functions, which introduce a very large API surface and have very few users. It is our judgment that they are out-of-scope for JAX itself, with users better-served by a hypothetical downstream project.

The scipy.spatial.distance submodule contains a useful collection of distance metrics, and it might be tempting to provide JAX wrappers for these. That said, with jit and vmap it would be straightforward for a user to define efficient versions of most of these from scratch if needed, so adding them to JAX is not particularly beneficial.
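To make the point concrete, here is a minimal sketch of a pairwise Euclidean distance matrix built from jit and vmap. The function names euclidean and pairwise, and the sample points, are invented for this example.

```python
import jax
import jax.numpy as jnp

def euclidean(u, v):
    # Distance between two 1D vectors.
    return jnp.sqrt(jnp.sum((u - v) ** 2))

# The inner vmap maps over rows of the second argument, the outer vmap
# over rows of the first, giving the full pairwise distance matrix.
pairwise = jax.jit(
    jax.vmap(jax.vmap(euclidean, in_axes=(None, 0)), in_axes=(0, None))
)

xa = jnp.array([[0.0, 0.0], [3.0, 4.0]])
xb = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
distances = pairwise(xa, xb)  # shape (2, 3)
print(distances)
```

Swapping in a different metric only requires changing the body of the two-vector function.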

Recommendation: consider deprecating and removing the Rotation and Slerp APIs, and consider scipy.spatial as a whole out-of-scope for future contributions.

scipy.special#

The scipy.special module includes implementations of a number of more specialized functions. In many cases, these functions are squarely in scope: for example, functions like gammaln, betainc, digamma, and many others correspond directly to available XLA primitives, and are clearly in-scope by Axis 1 and others.
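As one illustration, the sketch below checks jax.scipy.special.gammaln() against jax.lax.lgamma() and differentiates it. The input values are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import digamma, gammaln

x = jnp.array([0.5, 2.0, 5.0])

# The SciPy-style wrapper agrees with the corresponding lax function.
wrapper_out = gammaln(x)
primitive_out = lax.lgamma(x)

# Gradients are available as well: the derivative of log-gamma
# is the digamma function.
grad_out = jax.vmap(jax.grad(gammaln))(x)
print(wrapper_out, grad_out)
```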

Other functions require more complicated implementations; one example mentioned above is bessel_jn. Despite not aligning on Axes 1 and 2, these functions tend to be very strong along Axis 6: scipy.special provides fundamental functions necessary for computation in a variety of domains, so even functions with complicated implementations should lean toward in-scope, so long as the implementations are well-designed and robust.

There are a few existing function wrappers that we should take a closer look at; for example:

  • jax.scipy.special.lpmn(): this generates Legendre polynomials via a complicated fori_loop, in a way that does not match the scipy API (e.g. for scipy, z must be a scalar, while for JAX, z must be a 1D array). The function has few discoverable uses, making it a weak candidate along Axes 1, 2, 4, and 6.

  • jax.scipy.special.lpmn_values(): this has similar weaknesses to lpmn above.

  • jax.scipy.special.sph_harm(): this is built on lpmn, and similarly has an API that diverges from the corresponding scipy function.

  • jax.scipy.special.bessel_jn(): as discussed under Axis 4 above, this has weaknesses in terms of robustness of implementation and little usage. We might consider replacing it with a new, more robust implementation (e.g. #17038).

Recommendation: refactor and improve robustness & test coverage for bessel_jn. Consider deprecating lpmn, lpmn_values, and sph_harm if they cannot be modified to more closely match the scipy APIs.

scipy.stats#

The scipy.stats module contains a wide range of statistical functions, including discrete and continuous distributions, summary statistics, and hypothesis testing. JAX currently wraps a number of these in jax.scipy.stats, primarily including 20 or so statistical distributions, along with a few other functions (mode, rankdata, gaussian_kde). In general these are well-aligned with JAX: distributions usually are expressible in terms of efficient XLA operations, and the APIs are clean and functional.
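A brief sketch of why the functional distribution APIs fit well: the log-density is an ordinary function of arrays and parameters, so it can be differentiated directly. The data values and the helper name negative_log_likelihood are assumptions made for this example.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

x = jnp.array([-1.0, 0.0, 1.0])

# Distribution functions are plain functions of arrays and parameters.
log_density = norm.logpdf(x, loc=0.0, scale=1.0)

def negative_log_likelihood(loc):
    return -jnp.sum(norm.logpdf(x, loc=loc, scale=1.0))

# Gradient with respect to the location parameter, evaluated at the
# sample mean, where the likelihood is maximized.
grad_at_mean = jax.grad(negative_log_likelihood)(0.0)
print(log_density, grad_at_mean)
```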

We don’t currently have any wrappers for hypothesis testing functions, probably because these are less useful to the primary user-base of JAX.

Regarding distributions, in some cases, tensorflow_probability provides similar functionality, and in the future we might consider whether to deprecate the scipy.stats distributions in favor of that implementation.

Recommendation: going forward, we should treat statistical distributions and summary statistics as in-scope, and consider hypothesis tests and related functionality generally out-of-scope.