jax.experimental.sparse module#

The jax.experimental.sparse module includes experimental support for sparse matrix operations in JAX. It is under active development, and the API is subject to change. The primary interfaces made available are the BCOO sparse array type and the sparsify() transform.

Batched-coordinate (BCOO) sparse matrices#

The main high-level sparse object currently available in JAX is the BCOO, or batched-coordinate sparse array, which offers a compressed storage format compatible with JAX transformations, in particular JIT (e.g. jax.jit()), batching (e.g. jax.vmap()), and autodiff (e.g. jax.grad()).

Here is an example of creating a sparse array from a dense array:

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)
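
Because the size of the sparse representation depends on the matrix contents, fromdense() cannot infer the number of specified elements when traced under jit(); passing nse explicitly makes the construction jit-compatible. A minimal sketch, where nse=4 matches the matrix above:

>>> from jax import jit
>>> jit(lambda mat: sparse.BCOO.fromdense(mat, nse=4))(M)
BCOO(float32[3, 4], nse=4)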

Convert back to a dense array with the todense() method:

>>> M_sp.todense()
DeviceArray([[0., 1., 0., 2.],
             [3., 0., 0., 0.],
             [0., 0., 4., 0.]], dtype=float32)

The BCOO format is a somewhat modified version of the standard COO format, and the underlying sparse storage can be seen in the data and indices attributes:

>>> M_sp.data  # Explicitly stored data
DeviceArray([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices  # Indices of the stored data
DeviceArray([[0, 1],
             [0, 3],
             [1, 0],
             [2, 2]], dtype=int32)
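
Each row of indices gives the (row, column) position of the corresponding entry of data, so the stored values can be read back from the dense matrix with fancy indexing. A quick sketch:

>>> rows, cols = M_sp.indices.T
>>> bool(jnp.all(M[rows, cols] == M_sp.data))
True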

BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:

>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse  # "number of specified elements"
4
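
The "batched" in the name refers to optional leading batch dimensions, which BCOO stores along the leading axes of data and indices rather than as explicit sparse coordinates. A sketch using the n_batch argument of fromdense():

>>> M_batched = sparse.BCOO.fromdense(jnp.stack([M, 2 * M]), n_batch=1)
>>> M_batched.data.shape  # one batch axis, then nse
(2, 4)
>>> M_batched.indices.shape  # batch axis, nse, one index per sparse dimension
(2, 4, 2)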

BCOO objects also implement a number of array-like methods, allowing you to use them directly within JAX programs. For example, here we compute the transposed matrix-vector product:

>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
DeviceArray([18.,  3., 20.,  6.], dtype=float32)
>>> M.T @ y  # Compare to dense version
DeviceArray([18.,  3., 20.,  6.], dtype=float32)

BCOO objects are designed to be compatible with JAX transforms, including jax.jit(), jax.vmap(), jax.grad(), and others. For example:

>>> from jax import grad, jit
>>> def f(y):
...   return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
DeviceArray([3., 3., 4.], dtype=float32)
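
Batching works similarly; for example, mapping f over a stack of vectors with jax.vmap() (a sketch reusing f and y from above; recall that f sums the entries of M_sp.T @ y):

>>> from jax import vmap
>>> vmap(f)(jnp.stack([y, 2 * y]))
DeviceArray([47., 94.], dtype=float32)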

Note, however, that under normal circumstances jax.numpy and jax.lax functions do not know how to handle sparse matrices, so attempting to compute things like jnp.dot(M_sp.T, y) will result in an error (see the next section for a way around this).

Sparsify transform#

An overarching goal of the JAX sparse implementation is to provide a means to switch from dense to sparse computation seamlessly, without having to modify the dense implementation. The sparse module accomplishes this through the sparsify() transform.

Consider this function, which computes a more complicated result from a matrix and a vector input:

>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
DeviceArray([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Passing a sparse matrix to this function directly would result in an error, because jnp functions do not recognize sparse inputs. With sparsify(), however, we get a version of this function that does accept sparse matrices:

>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
DeviceArray([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Currently, support for sparsify() is limited to a couple dozen primitives, including:

  • generalized matrix-matrix products (dot_general_p)

  • generalized array transpose (transpose_p)

  • zero-preserving elementwise binary operations (add_p, mul_p)

  • zero-preserving elementwise unary operations (abs_p, neg_p, etc.; see the sketch after this list)

  • summation reductions (reduce_sum_p)

  • some higher-order functions (cond_p, while_p, scan_p)
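
For example, because abs_p is zero-preserving, applying jnp.abs under sparsify() yields another sparse array rather than densifying the input. A quick sketch:

>>> sparse.sparsify(jnp.abs)(M_sp)
BCOO(float32[3, 4], nse=4)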

This initial support is enough to enable some surprisingly sophisticated workflows, as the next section will show.

Example: sparse logistic regression#

As an example of a more complicated sparse workflow, let’s consider a simple logistic regression implemented in JAX. Notice that the following implementation has no reference to sparsity:

>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1] + 1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x
...
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]

This returns the best-fit parameters of a dense logistic regression problem. To fit the same model on sparse data, we can apply the sparsify() transform:

>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]
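
The sparse and dense fits agree to within the optimizer's tolerance; a loose check (the tolerance here is an illustrative choice, since exact values vary across hardware and library versions):

>>> bool(jnp.allclose(params_dense, params_sparse, atol=1e-2))
True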

API#

BCOO(args, *, shape[, indices_sorted, ...])
    Experimental batched COO matrix implemented in JAX.

sparsify(f[, use_tracer])
    Experimental sparsification transform.

bcoo_broadcast_in_dim(mat, *, shape, ...)
    Expand the size and rank of a BCOO array by duplicating the data.

bcoo_concatenate(operands, *, dimension)
    Sparse implementation of jax.lax.concatenate().

bcoo_dot_general(lhs, rhs, *, dimension_numbers)
    A general contraction operation.

bcoo_dot_general_sampled(A, B, indices, *, ...)
    A contraction operation with output computed at given sparse indices.

bcoo_extract(indices, mat)
    Extract BCOO data values from a dense matrix at given BCOO indices.

bcoo_fromdense(mat, *[, nse, n_batch, ...])
    Create a BCOO-format sparse matrix from a dense matrix.

bcoo_multiply_dense(sp_mat, v)
    An element-wise multiplication between a sparse and a dense array.

bcoo_multiply_sparse(lhs, rhs)
    An element-wise multiplication of two sparse arrays.

bcoo_reduce_sum(mat, *, axes)
    Sum array elements over the given axes.

bcoo_reshape(mat, *, new_sizes, dimensions)
    Sparse implementation of jax.lax.reshape().

bcoo_sort_indices(mat)
    Sort the indices of a BCOO array.

bcoo_sum_duplicates(mat[, nse])
    Sum duplicate indices within a BCOO array, returning an array with sorted indices.

bcoo_todense(mat)
    Convert a batched sparse matrix to a dense matrix.

bcoo_transpose(mat, *, permutation)
    Transpose a BCOO-format array.