jax.experimental.sparse module

The jax.experimental.sparse module includes experimental support for sparse matrix operations in JAX. It is under active development, and the API is subject to change. The primary interfaces made available are the BCOO sparse array type and the sparsify() transform.

Batched-coordinate (BCOO) sparse matrices

The main high-level sparse object currently available in JAX is the BCOO, or batched coordinate sparse array, which offers a compressed storage format compatible with JAX functions.

Here is an example of creating a sparse array from a dense array:

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)

Convert back to a dense array with the todense() method:

>>> M_sp.todense()
DeviceArray([[0., 1., 0., 2.],
             [3., 0., 0., 0.],
             [0., 0., 4., 0.]], dtype=float32)

The BCOO format is a somewhat modified version of the standard COO format; its compressed representation can be seen in the data and indices attributes:

>>> M_sp.data  # Explicitly stored data
DeviceArray([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data
DeviceArray([[0, 1],
             [0, 3],
             [1, 0],
             [2, 2]], dtype=int32)
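
Each row of indices holds the (row, column) coordinates of the corresponding entry of data, so, for instance, the explicitly stored values can be recovered by indexing the dense matrix:

>>> rows, cols = M_sp.indices[:, 0], M_sp.indices[:, 1]
>>> M[rows, cols]  # recovers the values in M_sp.data
DeviceArray([1., 2., 3., 4.], dtype=float32)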

BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:

>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse  # "number of specified elements"
4
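
Note that nse need not equal the number of nonzero entries: fromdense accepts an nse argument that pads the storage to a fixed size, which can be useful when the number of nonzeros is not known statically (for example under jit()). A minimal sketch:

>>> sparse.BCOO.fromdense(M, nse=6)
BCOO(float32[3, 4], nse=6)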

BCOO objects also implement a number of array-like methods, allowing you to use them directly within JAX programs. For example, here we compute the transposed matrix-vector product:

>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
DeviceArray([18.,  3., 20.,  6.], dtype=float32)
>>> M.T @ y  # Compare to dense version
DeviceArray([18.,  3., 20.,  6.], dtype=float32)

BCOO objects are designed to be compatible with JAX transforms, including jax.jit(), jax.vmap(), jax.grad(), and others. For example:

>>> from jax import grad, jit
>>> def f(y):
...   return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
DeviceArray([3., 3., 4.], dtype=float32)
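
Since jax.vmap() is listed among the supported transforms, the same function can, for example, be mapped over a batch of dense vectors (a minimal sketch; the expected values follow from the example above):

>>> from jax import vmap
>>> ys = jnp.stack([y, 2 * y])  # batch of dense vectors
>>> vmap(f)(ys)
DeviceArray([47., 94.], dtype=float32)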

Note, however, that under normal circumstances jax.numpy and jax.lax functions do not know how to handle sparse matrices, so attempting to compute things like jnp.dot(M_sp.T, y) will result in an error (however, see the next section).
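
For illustration, here is one way to observe this failure (the exact exception type and message depend on the JAX version):

>>> try:
...   jnp.dot(M_sp.T, y)  # dense function applied to a sparse operand
... except Exception:  # typically a TypeError
...   print("jnp.dot does not accept BCOO inputs")
jnp.dot does not accept BCOO inputs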

Sparsify transform

An overarching goal of the JAX sparse implementation is to provide a means to switch from dense to sparse computation seamlessly, without having to modify the dense implementation. The sparsify() transform accomplishes this.

Consider this function, which computes a more complicated result from a matrix and a vector input:

>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
DeviceArray([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Were we to pass a sparse matrix to this directly, it would result in an error, because jnp functions do not recognize sparse inputs. However, with sparsify(), we get a version of this function that does accept sparse matrices:

>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
DeviceArray([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Currently, support for sparsify() is limited to a couple dozen primitives, including the following (a small example combining several of them appears after the list):

  • generalized matrix-matrix products (dot_general_p)

  • generalized array transpose (transpose_p)

  • zero-preserving elementwise binary operations (add_p, mul_p)

  • zero-preserving elementwise unary operations (abs_p, neg_p, etc.)

  • summation reductions (reduce_sum_p)

  • some higher-order functions (cond_p, while_p, scan_p)
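
Here is a brief sketch combining several of these primitives within a single sparsified function (the values follow from M_sp and y defined earlier; exact primitive coverage may vary between JAX versions):

>>> @sparse.sparsify
... def g(M, v):
...   return jnp.sum(jnp.abs(M).T @ v)  # abs_p, transpose_p, dot_general_p, reduce_sum_p
...
>>> g(M_sp, y)
DeviceArray(47., dtype=float32)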

This initial support is enough to enable some surprisingly sophisticated workflows, as the next section will show.

Example: sparse logistic regression

As an example of a more complicated sparse workflow, let’s consider a simple logistic regression implemented in JAX. Notice that the following implementation has no reference to sparsity:

>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1] + 1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x
...
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]

This returns the best-fit parameters of a dense logistic regression problem. To fit the same model on sparse data, we can apply the sparsify() transform:

>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]
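
As expected, the sparse and dense fits agree to within the optimizer's tolerance; for instance (the tolerance here is chosen loosely for illustration, and exact values may vary by platform):

>>> jnp.allclose(params_dense, params_sparse, atol=1e-2)
DeviceArray(True, dtype=bool)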

API

class jax.experimental.sparse.BCOO(args, *, shape)[source]

Experimental batched COO matrix implemented in JAX

Parameters
  • (data, indices) – data and indices in batched COO format.

  • shape – shape of sparse array.

data

ndarray of shape [*batch_dims, nse, *dense_dims] containing the explicitly stored data within the sparse matrix.

Type

jax._src.numpy.lax_numpy.ndarray

indices

ndarray of shape [*batch_dims, nse, n_sparse] containing the indices of the explicitly stored data. Duplicate entries will be summed.

Type

jax._src.numpy.lax_numpy.ndarray

Examples

Create a sparse array from a dense array:

>>> import jax.numpy as jnp
>>> from jax.experimental.sparse import BCOO
>>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
>>> M_sp = BCOO.fromdense(M)
>>> M_sp
BCOO(float32[2, 3], nse=3)

Examine the internal representation:

>>> M_sp.data
DeviceArray([2., 1., 4.], dtype=float32)
>>> M_sp.indices
DeviceArray([[0, 1],
             [1, 0],
             [1, 2]], dtype=int32)

Create a dense array from a sparse array:

>>> M_sp.todense()
DeviceArray([[0., 2., 0.],
             [1., 0., 4.]], dtype=float32)

Create a sparse array from COO data & indices:

>>> data = jnp.array([1., 3., 5.])
>>> indices = jnp.array([[0, 0],
...                      [1, 1],
...                      [2, 2]])
>>> mat = BCOO((data, indices), shape=(3, 3))
>>> mat
BCOO(float32[3, 3], nse=3)
>>> mat.todense()
DeviceArray([[1., 0., 0.],
             [0., 3., 0.],
             [0., 0., 5.]], dtype=float32)
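
The batch dimensions mentioned in the data and indices shapes above appear when fromdense is called with a nonzero n_batch; a minimal sketch (the nse chosen per batch may differ across versions):

>>> Mb = jnp.stack([M, 2 * M])  # shape (2, 2, 3): one leading batch dimension
>>> Mb_sp = BCOO.fromdense(Mb, n_batch=1)
>>> Mb_sp.data.shape  # [*batch_dims, nse, *dense_dims]
(2, 3)
>>> Mb_sp.indices.shape  # [*batch_dims, nse, n_sparse]
(2, 3, 2)
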
jax.experimental.sparse.sparsify(f, use_tracer=False)[source]

Experimental sparsification transform.

Examples

Decorate JAX functions to make them compatible with jax.experimental.sparse.BCOO matrices:

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> @sparse.sparsify
... def f(M, v):
...   return 2 * M.T @ v
>>> M = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
>>> v = jnp.array([3, 4, 2])
>>> f(M, v)
DeviceArray([ 64,  82, 100, 118], dtype=int32)