JEP 18137: Scope of JAX NumPy & SciPy Wrappers
Until now, the intended scope of
jax.numpy and jax.scipy has been relatively
ill-defined. This document proposes a well-defined scope for these packages to better guide
and evaluate future contributions, and to motivate the removal of some out-of-scope code.
From the beginning, JAX has aimed to provide a NumPy-like API for executing code in XLA,
and a big part of the project’s development has been building out the
jax.numpy and jax.scipy namespaces as JAX-based implementations of NumPy and SciPy APIs. There has always
been an implicit understanding that some parts of
numpy and scipy are out-of-scope
for JAX, but this scope has not been well defined. This can lead to confusion and frustration
for contributors, because there’s no clear answer to whether potential
jax.numpy and jax.scipy contributions will be accepted into JAX.
Why Limit the Scope?
To avoid leaving this unsaid, we should be explicit: it is a fact that any code included in a project like JAX incurs a small but nonzero ongoing maintenance burden for the developers. The success of a project over time directly relates to the ability of maintainers to continue this maintenance for the sum of all the project’s parts: documenting functionality, responding to questions, fixing bugs, etc. For long-term success and sustainability of any software tool, it’s vital that maintainers carefully weigh whether any particular contribution will be a net positive for the project given its goals and resources.
This document proposes a rubric of six axes along which any particular
numpy or scipy API can be judged for inclusion into JAX. An API which is strong along all axes
is an excellent candidate for inclusion in the JAX package; a significant weakness along any of
the six axes is a good argument against inclusion in JAX.
Axis 1: XLA alignment
The first axis we consider is the degree to which the proposed API aligns with native XLA
operations. For example,
jax.numpy.exp() is a function that more-or-less directly mirrors
jax.lax.exp. A large number of functions in
scipy.linalg and other modules meet this criterion: such functions pass the XLA-alignment check
when considering their inclusion into JAX.
On the other end, there are functions like
numpy.unique(), which do not directly correspond
to any XLA operation, and in some cases are fundamentally incompatible with JAX’s current
computational model, which requires statically-shaped arrays (e.g.
unique returns a
value-dependent dynamic array shape). Such functions do not pass the XLA alignment check
when considering their inclusion into JAX.
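As a rough illustration of the two ends of this axis, the sketch below first compares jnp.exp with lax.exp, then calls jnp.unique eagerly and under jit. The size and fill_value arguments are how current JAX versions allow the caller to fix the output shape ahead of time; the variable and helper names are chosen only for this example.

```python
import jax
import jax.numpy as jnp
from jax import lax

xf = jnp.array([0.0, 1.0, 2.0])
# jnp.exp gives the same result as the XLA-level lax.exp.
same_as_lax = bool(jnp.allclose(jnp.exp(xf), lax.exp(xf)))

x = jnp.array([3, 1, 3, 2, 1])
# Eagerly, the output shape can be derived from the concrete values.
u_eager = jnp.unique(x)

# Under jit the values are abstract, so the output shape cannot be determined.
try:
    jax.jit(jnp.unique)(x)
    failed_under_jit = False
except Exception:
    failed_under_jit = True

# Supplying a static size (padding with fill_value) restores a static shape.
@jax.jit
def unique_padded(v):
    return jnp.unique(v, size=4, fill_value=-1)

u_jit = unique_padded(x)
```

The padded variant works under jit only because the caller commits to an output length in advance, which is a departure from the NumPy semantics.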
We also consider as part of this axis the need for pure function semantics. For example,
numpy.random is built on an implicitly-updated state-based RNG, which is fundamentally
incompatible with JAX’s computational model built on XLA.
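The difference in semantics can be sketched as follows, using NumPy's global seeded generator on one side and an explicit JAX PRNG key on the other. This is only an illustration of the two models, not a statement about how either library is implemented internally.

```python
import numpy as np
import jax

# NumPy: each call silently advances hidden global state,
# so two identical calls return different values.
np.random.seed(0)
a = np.random.uniform()
b = np.random.uniform()

# JAX: the PRNG state is an explicit key passed to a pure function;
# the same key always produces the same value.
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
x1 = jax.random.uniform(k1)
x1_again = jax.random.uniform(k1)
x2 = jax.random.uniform(k2)  # new randomness requires a new key
```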
Axis 2: Array API Alignment
The second axis we consider focuses on the
Python Array API Standard: this is in some
senses a community-driven outline of which array operations are central to array-oriented
programming across a wide range of user communities. If an API in numpy or scipy is
listed within the Array API standard, it is a strong signal that JAX should include it.
Using the example from above, the Array API standard includes several variants of
numpy.unique(), which
suggests that, despite the function not being precisely aligned with XLA, it is important
enough to the Python user community that JAX should perhaps implement it.
Axis 3: Existence of Downstream Implementations
For functionality that does not align with Axis 1 or 2, an important consideration for
inclusion into JAX is whether there exist well-supported downstream packages that supply
the functionality in question. A good example of this is
scipy.optimize: while JAX does
include a minimal set of wrappers of
scipy.optimize functionality, a much more complete
treatment exists in the JAXopt package, which is actively
maintained by JAX collaborators. In cases like this, we should lean toward pointing users
and contributors to these specialized packages rather than re-implementing such APIs in
JAX itself.
Axis 4: Complexity & Robustness of Implementation
For functionality that does not align with XLA, one consideration is the degree of
complexity of the proposed implementation. This aligns to some degree with Axis 1,
but nevertheless is important to call out. A number of functions have been contributed
to JAX which have relatively complex implementations which are difficult to validate
and introduce outsized maintenance burdens; an example is
jax.scipy.special.bessel_jn():
as of the writing of this JEP, its current implementation is a non-straightforward
iterative approximation that has
convergence issues in some domains,
and proposed fixes introduce further
complexity. Had we more carefully weighed the complexity and robustness of the
implementation when accepting the contribution, we might have chosen not to accept it
into the package.
Axis 5: Functional vs. Object-Oriented APIs
JAX works best with functional APIs rather than object-oriented APIs. Object-oriented APIs can often hide impure semantics, which makes them difficult to implement well. NumPy and SciPy generally stick to functional APIs, but sometimes provide object-oriented convenience wrappers.
An example of this is
numpy.polynomial.Polynomial, which wraps lower-level operations
such as numpy.polydiv(). In general, when there are both functional
and object-oriented APIs available, JAX should avoid providing wrappers for the
object-oriented APIs and instead provide wrappers for the functional APIs.
In cases where only the object-oriented APIs exist, JAX should avoid providing wrappers unless the case is strong along other axes.
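A minimal sketch of the functional style, assuming the polynomial helpers that jax.numpy currently exposes (polyadd and polyval are used here). Coefficients are plain arrays, so the functions compose directly with transformations such as jit; the example polynomials are arbitrary.

```python
import jax
import jax.numpy as jnp

# Coefficient arrays, highest power first.
p = jnp.array([1.0, 2.0, 3.0])  # x**2 + 2*x + 3
q = jnp.array([1.0, 1.0])       # x + 1

# Pure functions on arrays: no polynomial object holds any state.
p_plus_q = jnp.polyadd(p, q)

# Because polyval is a pure function, it can be jit-compiled as-is.
evaluate_p = jax.jit(lambda x: jnp.polyval(p, x))
p_at_2 = evaluate_p(2.0)
```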
Axis 6: General “Importance” to JAX Users & Stakeholders
The decision to include a NumPy/SciPy API in JAX should also take into account the importance of the algorithm to the general user community. It is admittedly difficult to quantify who is a “stakeholder” and how this importance should be measured; but we include this to make clear that any decision about what to include in JAX’s NumPy and SciPy wrappers will involve some amount of discretion that cannot be easily quantified.
For existing APIs, searches for usage on GitHub may be useful in establishing importance
or lack thereof; as an example, we might return to
jax.scipy.special.bessel_jn()
discussed above: a search shows that this function has only a
handful of uses
on GitHub, probably partly due to the previously mentioned accuracy issues.
Evaluation: what’s in scope?
In this section, we’ll attempt to evaluate the NumPy and SciPy APIs, including some examples from the current JAX API, in light of the above rubric. This will not be a comprehensive listing of all existing functions and classes, but rather a more general discussion by submodule and topic, with relevant examples.
We consider the functions in the main
numpy namespace to be essentially all in-scope
for JAX, due to its general alignment with XLA (Axis 1) and the Python Array API
(Axis 2), as well as its general importance to the JAX user community (Axis 6).
Some functions are perhaps borderline (functions like
np.union1d() arguably fail parts of the rubric) but for
simplicity we declare that all array functions in the main numpy namespace are in-scope
for JAX.
The numpy.linalg and numpy.fft submodules contain many functions that
broadly align with functionality provided by XLA. Others have complicated device-specific
lowerings, but represent a case where importance to stakeholders (Axis 6) outweighs complexity.
For this reason, we deem both of these submodules in-scope for JAX.
numpy.random is out-of-scope for JAX, because state-based RNGs are fundamentally
incompatible with JAX’s computation model. We instead focus on
jax.random,
which offers similar functionality using a counter-based PRNG.
The numpy.ma and numpy.polynomial submodules are mostly concerned with
providing object-oriented interfaces to computations that can be expressed via other
functional means (Axis 5); for this reason, we deem them out-of-scope for JAX.
NumPy’s testing functionality only really makes sense for host-side computation,
and so we don’t include any wrappers for it in JAX. That said, JAX arrays are
compatible with numpy.testing, and JAX makes frequent use of it throughout
the JAX test suite.
SciPy has no functions in the top-level namespace, but includes a number of submodules. We consider each below, leaving out modules which have been deprecated.
The scipy.cluster module includes tools for hierarchical clustering, k-means,
and related algorithms. These are weak along several axes, and would be better
served by a downstream package. One function already exists within JAX
(jax.scipy.cluster.vq.vq()) but has
no obvious usage
on GitHub: this suggests that clustering is not broadly important to JAX users.
Recommendation: deprecate and remove
jax.scipy.cluster.vq.vq().
The scipy.constants module includes mathematical and physical constants.
These constants can be used directly with JAX, and so there is no reason to
re-implement this in JAX.
The scipy.datasets module includes tools to fetch and load datasets.
These fetched datasets can be used directly with JAX, and so there is no
reason to re-implement this in JAX.
The scipy.fft module contains functions that broadly align with functionality
provided by XLA, and fare well along other axes as well. For this reason,
we deem them in-scope for JAX.
The scipy.integrate module contains functions for numerical integration. The
more sophisticated of these (such as
ode) are out-of-scope for JAX by
axes 1 & 4, since they tend to be loopy algorithms based on dynamic numbers of
evaluations.
jax.experimental.ode.odeint() is related, but rather limited and not
under any active development.
JAX does currently include
jax.scipy.integrate.trapezoid(), but this is only because
numpy.trapz() was recently deprecated in favor of this. For any particular inputs,
its implementation could be replaced with one line of
jax.numpy expressions, so
it’s not a particularly useful API to provide.
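For instance, for one-dimensional sample points x and values y, the trapezoidal rule can be written as a single jax.numpy expression. This is one possible formulation given as a sketch; it is not necessarily how the JAX wrapper itself is implemented, and the sample data are arbitrary.

```python
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 5)  # sample points
y = x ** 2                     # function values at those points

# Trapezoidal rule: average adjacent values, weight by the spacing, and sum.
area = jnp.sum((y[1:] + y[:-1]) * jnp.diff(x)) / 2
```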
Based on Axes 1, 2, 4, and 6,
scipy.integrate should be considered out-of-scope for JAX.
Recommendation: remove
jax.scipy.integrate.trapezoid(), which was added in JAX 0.4.14.
The scipy.interpolate module provides both low-level and object-oriented routines
for interpolating in one or more dimensions. These APIs rate poorly along a number
of the axes above: they are class-based rather than low-level, and none but the
simplest methods can be expressed efficiently in terms of XLA operations.
JAX does currently have wrappers for
scipy.interpolate.RegularGridInterpolator.
Were we considering this contribution today, we would probably reject it by the
above criteria. But this code has been fairly stable so there is not much downside
to continuing to maintain it.
Going forward, we should consider other members of
scipy.interpolate to be
out-of-scope for JAX.
The scipy.io submodule has to do with file input/output. There is no reason
to re-implement this in JAX.
The scipy.linalg submodule contains functions that broadly align with functionality
provided by XLA, and fast linear algebra is broadly important to the JAX user community.
For this reason, we deem it in-scope for JAX.
The scipy.ndimage submodule contains a set of tools for working on image data. Many
of these overlap with tools in
scipy.signal (e.g. convolutions and filtering). JAX
currently provides one
scipy.ndimage API, in
jax.scipy.ndimage.map_coordinates().
Additionally, JAX provides some image-related tools in the
jax.image module. The
DeepMind ecosystem includes dm-pix, a
more full-featured set of tools for image manipulation in JAX. Given all these factors,
we suggest that
scipy.ndimage should be considered out-of-scope for JAX core; we can
point interested users and contributors to dm-pix. We can consider moving
map_coordinates to
dm-pix or to another appropriate package.
The scipy.odr module provides an object-oriented wrapper around
the ODRPACK package for
performing orthogonal distance regressions. It is not clear that this could be cleanly
expressed using existing JAX primitives, and so we deem it out of scope for JAX itself.
The scipy.optimize module provides high-level and low-level interfaces for optimization.
Such functionality is important to a lot of JAX users, and very early on JAX created
jax.scipy.optimize wrappers. However, developers of these routines soon realized that
the scipy.optimize API was too constraining, and different teams began working on the
JAXopt package and the
Optimistix package, each of which contains
a much more comprehensive and better-tested set of optimization routines in JAX.
Because of these well-supported external packages, we now consider
jax.scipy.optimize
to be out-of-scope for JAX.
Recommendation: deprecate
jax.scipy.optimize and/or make it a lightweight wrapper
around an optional JAXopt or Optimistix dependency.
The scipy.signal module is mixed: some functions are squarely in-scope for JAX
(e.g. correlate and
convolve, which are more user-friendly wrappers of
lax.conv_general_dilated), while many others are squarely out-of-scope (domain-specific
tools with no viable lowering path to XLA). Potential contributions to
jax.scipy.signal
will have to be weighed on a case-by-case basis.
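As an example of the in-scope end of this module, the sketch below calls jax.scipy.signal.convolve on two short 1D arrays. The input values are arbitrary and chosen only so the result is easy to check by hand.

```python
import jax.numpy as jnp
from jax.scipy import signal

x = jnp.array([1.0, 2.0, 3.0])  # input signal
k = jnp.array([0.0, 1.0, 0.5])  # convolution kernel

# Full discrete convolution (the default mode): output length is len(x) + len(k) - 1.
full = signal.convolve(x, k)

# mode="same" returns an output with the same length as the first input.
same = signal.convolve(x, k, mode="same")
```

The caller does not need to specify dimension numbers, padding, or strides, which is the convenience these wrappers add over calling the lower-level convolution primitive directly.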
The scipy.sparse submodule mainly contains data structures for storing and operating
on sparse matrices and arrays in a variety of formats. Additionally,
scipy.sparse.linalg
contains a number of matrix-free solvers, suitable for use with sparse matrices,
dense matrices, and linear operators.
The scipy.sparse array and matrix data structures are out-of-scope for JAX, because
they do not align with JAX’s computational model (e.g. many operations depend on
dynamically-sized buffers). JAX has developed the
jax.experimental.sparse module
as an alternative set of data structures that are more in-line with JAX’s computational
constraints. For these reasons, we consider the data structures in
scipy.sparse to
be out-of-scope for JAX.
On the other hand,
scipy.sparse.linalg has proven to be an interesting area, and
jax.scipy.sparse.linalg includes the
bicgstab, cg, and
gmres solvers. These
are useful to the JAX user community (Axis 6) but aside from this do not fare well
along other axes. They would be very suitable for moving into a downstream library;
one potential option may be Lineax, which features
a number of linear solvers built on JAX.
Recommendation: explore moving sparse solvers into Lineax, and otherwise treat scipy.sparse as out-of-scope for JAX.
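To illustrate what "matrix-free" means for these solvers, the sketch below passes a function computing the matrix-vector product to jax.scipy.sparse.linalg.cg rather than a matrix. The small dense system and the matvec helper are hypothetical and only serve to make the result easy to check.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# A small symmetric positive-definite system, used only for illustration.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

def matvec(v):
    # The solver only needs the action of A on a vector, not A itself.
    return A @ v

# cg returns the solution and an info placeholder.
x, _ = cg(matvec, b)
residual = jnp.linalg.norm(matvec(x) - b)
```

In practice the function could wrap a sparse or structured operator that never materializes a dense matrix.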
The scipy.spatial module contains mainly object-oriented interfaces to spatial/distance
computations and nearest neighbor searches. It is mostly out-of-scope for JAX.
The scipy.spatial.transform submodule provides tools for manipulating three-dimensional
spatial rotations. It is a relatively complicated object-oriented interface, and could
perhaps be better served by a downstream project. JAX currently contains partial
implementations of Rotation and Slerp within jax.scipy.spatial.transform;
these are object-oriented wrappers of otherwise basic
functions, which introduce a very large API surface and have very few users. It is our
judgment that they are out-of-scope for JAX itself, with users better-served by a
hypothetical downstream project.
The scipy.spatial.distance submodule contains a useful collection of distance metrics,
and it might be tempting to provide JAX wrappers for these. That said, with jit and vmap
it would be straightforward for a user to define efficient versions of most of these from
scratch if needed, so adding them to JAX is not particularly beneficial.
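As a sketch of how little code this takes, the example below defines a Euclidean distance on a single pair of vectors and uses vmap twice to obtain all pairwise distances between two sets of points, roughly in the spirit of a pairwise-distance routine. The function names are hypothetical and chosen for this example only.

```python
import jax
import jax.numpy as jnp

def euclidean(u, v):
    # Distance between two individual vectors.
    return jnp.sqrt(jnp.sum((u - v) ** 2))

# Inner vmap: one row of X against every row of Y.
# Outer vmap: repeat for every row of X. jit compiles the whole computation.
pairwise = jax.jit(
    jax.vmap(jax.vmap(euclidean, in_axes=(None, 0)), in_axes=(0, None))
)

X = jnp.array([[0.0, 0.0], [3.0, 4.0]])
Y = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
D = pairwise(X, Y)  # D[i, j] is the distance between X[i] and Y[j]
```

Swapping in a different metric only requires changing the body of the scalar function.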
Recommendation: consider deprecating and removing the
Rotation and
Slerp APIs, and
consider scipy.spatial as a whole out-of-scope for future contributions.
The scipy.special module includes implementations of a number of more specialized
functions. In many cases, these functions are squarely in scope: for example, functions
like digamma, and many others, correspond directly to available
XLA primitives, and are clearly in-scope by Axis 1 and others.
Other functions require more complicated implementations; one example mentioned above
is bessel_jn. Despite not aligning on Axes 1 and 2, these functions tend to be very
strong along Axis 6:
scipy.special provides fundamental functions necessary for
computation in a variety of domains, so even functions with complicated implementations
should lean toward in-scope, so long as the implementations are well-designed and robust.
There are a few existing function wrappers that we should take a closer look at; for example:
jax.scipy.special.lpmn(): this generates Legendre polynomials via a complicated fori_loop, in a way that does not match the scipy API (e.g. for scipy,
z must be a scalar, while for JAX,
z must be a 1D array). The function has few discoverable uses, making it a weak candidate along Axes 1, 2, 4, and 6.
jax.scipy.special.lpmn_values(): this has similar weaknesses to
lpmn above.
jax.scipy.special.sph_harm(): this is built on lpmn, and similarly has an API that diverges from the corresponding
scipy function.
jax.scipy.special.bessel_jn(): as discussed under Axis 4 above, this has weaknesses in terms of robustness of implementation and little usage. We might consider replacing it with a new, more robust implementation (e.g. #17038).
Recommendation: refactor and improve robustness & test coverage for
bessel_jn. Consider deprecating
lpmn,
lpmn_values, and
sph_harm if they cannot be modified to more closely match the
scipy APIs.
The scipy.stats module contains a wide range of statistical functions, including discrete
and continuous distributions, summary statistics, and hypothesis testing. JAX currently wraps
a number of these in
jax.scipy.stats, primarily including 20 or so statistical distributions,
along with a few other functions (such as
gaussian_kde). In general these are
well-aligned with JAX: distributions usually are expressible in terms of efficient XLA operations,
and the APIs are clean and functional.
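A brief sketch of why the functional form matters: distribution functions such as jax.scipy.stats.norm.logpdf are plain functions of arrays, so they can be passed directly to transformations like jit and grad. The evaluation points below are arbitrary.

```python
import jax
from jax.scipy.stats import norm

# The log-density is an ordinary function, so it can be jit-compiled directly.
log_density = jax.jit(norm.logpdf)
value = log_density(0.0)  # log-density of the standard normal at 0

# Differentiating with respect to the first argument gives the score d/dx log p(x).
score = jax.grad(norm.logpdf)(1.0)
```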
We don’t currently have any wrappers for hypothesis testing functions, probably because these are less useful to the primary user-base of JAX.
Regarding distributions, in some cases,
tensorflow_probability provides similar functionality,
and in the future we might consider whether to deprecate the scipy.stats distributions in favor
of that implementation.
Recommendation: going forward, we should treat statistical distributions and summary statistics as in-scope, and consider hypothesis tests and related functionality generally out-of-scope.