jax.experimental.sparse module

The jax.experimental.sparse module includes experimental support for sparse matrix operations in JAX. It is under active development, and the API is subject to change. The primary interfaces it provides are the BCOO sparse array type and the sparsify() transform.

Batched-coordinate (BCOO) sparse matrices

The main high-level sparse object currently available in JAX is the BCOO, or batched coordinate sparse array, which offers a compressed storage format compatible with JAX transformations, in particular JIT (e.g. jax.jit()), batching (e.g. jax.vmap()) and autodiff (e.g. jax.grad()).

Here is an example of creating a sparse array from a dense array:

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)

Convert back to a dense array with the todense() method:

>>> M_sp.todense()
Array([[0., 1., 0., 2.],
       [3., 0., 0., 0.],
       [0., 0., 4., 0.]], dtype=float32)

The BCOO format is a somewhat modified version of the standard COO format, and its underlying sparse representation can be seen in the data and indices attributes:

>>> M_sp.data  # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices  # Indices of the stored data
Array([[0, 1],
       [0, 3],
       [1, 0],
       [2, 2]], dtype=int32)
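To make the (data, indices) layout concrete, the sketch below rebuilds the dense matrix by scatter-adding each stored value into its (row, column) position. It also constructs the same array directly from a (data, indices) tuple using the BCOO(args, *, shape) constructor listed in the API reference. The matrix M is restated so the block runs on its own.

```python
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)

# Each row of M_sp.indices is a (row, col) coordinate; scatter-add the
# matching entry of M_sp.data into a zero matrix of the same shape.
rows, cols = M_sp.indices[:, 0], M_sp.indices[:, 1]
rebuilt = jnp.zeros(M_sp.shape, M_sp.dtype).at[rows, cols].add(M_sp.data)

# The same sparse array, built directly from its (data, indices) buffers.
M_sp2 = sparse.BCOO((M_sp.data, M_sp.indices), shape=M_sp.shape)
```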

BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:

>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse  # "number of specified elements"
4
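The "batched" part of BCOO is controlled by how many leading axes are treated as batch dimensions. Below is a minimal sketch using the n_batch argument of BCOO.fromdense and the n_batch, n_sparse, and n_dense attributes. With one batch axis, each row stores its own coordinates, padded to the largest per-row count (row 0 has two nonzeros, so nse is 2).

```python
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])

# Default layout: no batch or dense axes, so both axes are sparse.
M_sp = sparse.BCOO.fromdense(M)
print(M_sp.n_batch, M_sp.n_sparse, M_sp.n_dense)  # 0 2 0

# Treat axis 0 as a batch axis: data is (batch, nse) and
# indices is (batch, nse, n_sparse).
M_b = sparse.BCOO.fromdense(M, n_batch=1)
print(M_b.data.shape, M_b.indices.shape)
```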

BCOO objects also implement a number of array-like methods, which let you use them directly within JAX programs. For example, here we compute the transposed matrix-vector product:

>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
Array([18.,  3., 20.,  6.], dtype=float32)
>>> M.T @ y  # Compare to dense version
Array([18.,  3., 20.,  6.], dtype=float32)
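The same @ operator also handles a sparse matrix times a dense matrix, returning a dense result. Here is a short sketch that compares it against the dense computation. The right-hand matrix B is arbitrary illustrative data.

```python
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)

# Illustrative dense right-hand side with shape (4, 2).
B = jnp.array([[1., 0.],
               [0., 1.],
               [1., 1.],
               [2., 0.]])

out_sparse = M_sp @ B  # sparse-dense matmul; the result is a dense array
out_dense = M @ B      # reference dense computation
```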

BCOO objects are designed to be compatible with JAX transforms, including jax.jit(), jax.vmap(), jax.grad(), and others. For example:

>>> from jax import grad, jit
>>> def f(y):
...   return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)
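BCOO arrays are pytrees, so they can also be passed as arguments to jitted functions and combined with jax.vmap. The sketch below maps a jitted sparse product over a batch of dense vectors while holding the sparse matrix fixed (in_axes=None for it). The batch ys is illustrative data.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)

@jax.jit
def matvec(A, v):
    # A is a BCOO argument; jit traces its data and indices buffers.
    return A.T @ v

# Two illustrative vectors stacked along axis 0.
ys = jnp.array([[3., 6., 5.],
                [1., 0., 0.]])

# Broadcast the sparse matrix (None) and map over the rows of ys (0).
batched = jax.vmap(matvec, in_axes=(None, 0))(M_sp, ys)
```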

Note, however, that jax.numpy and jax.lax functions do not by themselves know how to handle sparse matrices, so attempting to compute something like jnp.dot(M_sp.T, y) directly will result in an error (the next section shows how to work around this).

Sparsify transform

An overarching goal of the JAX sparse implementation is to provide a means to switch from dense to sparse computation seamlessly, without having to modify the dense implementation. This experimental module accomplishes this through the sparsify() transform.

Consider this function, which computes a more complicated result from a matrix and a vector input:

>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

Were we to pass a sparse matrix to this directly, it would result in an error, because jnp functions do not recognize sparse inputs. However, with sparsify(), we get a version of this function that does accept sparse matrices:

>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)
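Two further properties are useful in practice. First, a sparsified function can itself be wrapped in jax.jit. Second, when every step is zero-preserving, the output stays sparse. The sketch below shows both: negation followed by absolute value (neg_p then abs_p) maps the sparse input to a BCOO result. The lambda is an illustrative example, not part of the original text.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)
y = jnp.array([3., 6., 5.])

def f(M, v):
    return 2 * jnp.dot(jnp.log1p(M.T), v) + 1

# sparsify composes with jit: the sparse-aware function can be compiled.
f_sp_jit = jax.jit(sparse.sparsify(f))
dense_out = f_sp_jit(M_sp, y)

# Only zero-preserving unary ops here, so the result is still a BCOO.
abs_neg = sparse.sparsify(lambda A: jnp.abs(-A))(M_sp)
```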

Support for sparsify() includes a large number of the most common primitives, including:

  • generalized (batched) matrix products & Einstein summations (dot_general_p)

  • zero-preserving elementwise binary operations (e.g. add_p, mul_p, etc.)

  • zero-preserving elementwise unary operations (e.g. abs_p, neg_p, etc.)

  • summation reductions (reduce_sum_p)

  • general indexing operations (slice_p, dynamic_slice_p, gather_p)

  • concatenation and stacking (concatenate_p)

  • transposition & reshaping (transpose_p, reshape_p, squeeze_p, broadcast_in_dim_p)

  • some higher-order functions (cond_p, while_p, scan_p)

  • some simple 1D convolutions (conv_general_dilated_p)
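As a small illustration of chaining several primitives from this list, the hypothetical function below rotates the rows of a matrix with two static slices (jax.lax.slice, which binds slice_p) and a concatenation (concatenate_p), then sums each row (reduce_sum_p). The same function runs unchanged on the dense matrix for comparison. Since a reduction of a sparse array may itself be sparse, sparse.todense is used to normalize the output; it passes dense inputs through.

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)

def rotated_row_sums(A):
    # Hypothetical helper: move row 0 to the end, then sum each row.
    head = lax.slice(A, (0, 0), (1, 4))    # slice_p
    tail = lax.slice(A, (1, 0), (3, 4))    # slice_p
    rotated = jnp.concatenate([tail, head], axis=0)  # concatenate_p
    return rotated.sum(axis=1)             # reduce_sum_p

out = sparse.todense(sparse.sparsify(rotated_row_sums)(M_sp))
ref = rotated_row_sums(M)
```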

Nearly any jax.numpy function that lowers to these supported primitives can be used within a sparsify transform to operate on sparse arrays. This set of primitives is enough to enable relatively sophisticated sparse workflows, as the next section will show.

Example: sparse logistic regression

As an example of a more complicated sparse workflow, let’s consider a simple logistic regression implemented in JAX. Notice that the following implementation has no reference to sparsity:

>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1] + 1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x
...
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]
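For reference, the code above fits a standard logistic model. It writes the sigmoid through tanh using the identity below, which follows from tanh(x/2) = (1 - e^{-x}) / (1 + e^{-x}). It then minimizes the mean binary cross-entropy, with params[0] as the intercept θ₀ and params[1:] as the weights θ_{1:}:

```latex
\sigma(x) = \frac{1}{1 + e^{-x}}
          = \tfrac{1}{2}\left(\tanh\!\left(\tfrac{x}{2}\right) + 1\right),
\qquad
\hat{y}_i = \sigma\!\left(\theta_0 + x_i^\top \theta_{1:}\right),
\qquad
\mathcal{L}(\theta) = -\frac{1}{n}\sum_{i=1}^{n}
  \left[\, y_i \log \hat{y}_i + (1 - y_i)\log\left(1 - \hat{y}_i\right) \right]
```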

This returns the best-fit parameters of a dense logistic regression problem. To fit the same model on sparse data, we can apply the sparsify() transform:

>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]
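The small differences between the two parameter vectors come from the optimizer run rather than the model. One way to check the sparse path directly is to sparsify only the loss and compare it with the dense loss at fixed parameters. The sketch below restates the model functions and uses a hypothetical 4×3 design matrix with many zeros, illustrative labels, and illustrative parameters. At zero parameters every prediction is 0.5, so the loss should equal log 2.

```python
import jax.numpy as jnp
from jax.experimental import sparse

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def y_model(params, X):
    return sigmoid(jnp.dot(X, params[1:]) + params[0])

def loss(params, X, y):
    y_hat = y_model(params, X)
    return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))

# Hypothetical toy data: a small, genuinely sparse design matrix.
X = jnp.array([[0., 1., 0.],
               [2., 0., 0.],
               [0., 0., 3.],
               [0., 1., 1.]])
y = jnp.array([1., 0., 1., 0.])
Xsp = sparse.BCOO.fromdense(X)

loss_sp = sparse.sparsify(loss)            # only X is sparse; params, y stay dense
params = jnp.array([0.1, -0.2, 0.3, 0.05])  # illustrative parameter values

dense_val = loss(params, X, y)
sparse_val = loss_sp(params, Xsp, y)
zero_val = loss_sp(jnp.zeros(4), Xsp, y)   # every y_hat is 0.5 here
```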

Sparse API Reference

sparsify(f[, use_tracer])

Experimental sparsification transform.

grad(fun[, argnums, has_aux])

Sparse-aware version of jax.grad().

value_and_grad(fun[, argnums, has_aux])

Sparse-aware version of jax.value_and_grad().

empty(shape[, dtype, index_dtype, sparse_format])

Create an empty sparse array.

eye(N[, M, k, dtype, index_dtype, sparse_format])

Create a 2D sparse identity matrix.

todense(arr)

Convert input to a dense matrix.

random_bcoo(key, shape, *[, dtype, ...])

Generate a random BCOO matrix.

JAXSparse(args, *, shape)

Base class for high-level JAX sparse objects.
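A brief sketch exercising several of the entries above with their default arguments: sparse.eye, sparse.empty, and sparse.random_bcoo as constructors, sparse.todense for densifying, and sparse.grad for differentiating with respect to a sparse argument. With sparse.grad, the gradient is restricted to the argument's sparsity pattern, so the result is a BCOO with the same stored indices. The input jnp.arange(6.) has five nonzeros.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

I = sparse.eye(3)           # 3x3 identity, BCOO by default
E = sparse.empty((2, 3))    # sparse array with no stored elements
R = sparse.random_bcoo(jax.random.PRNGKey(0), (4, 5))  # random pattern

# sparse.todense densifies sparse inputs (dense inputs pass through).
I_dense = sparse.todense(I)
E_dense = sparse.todense(E)

# Gradient with respect to the sparse argument X (argnums=0 by default).
X = sparse.BCOO.fromdense(jnp.arange(6.))
v = jnp.ones(6)
g = sparse.grad(lambda X, v: X @ v)(X, v)
```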

BCOO Data Structure

BCOO is the Batched COO format, and is the main sparse data structure implemented in jax.experimental.sparse. Its operations are compatible with JAX’s core transformations, including batching (e.g. jax.vmap()) and autodiff (e.g. jax.grad()).

BCOO(args, *, shape[, indices_sorted, ...])

Experimental batched COO matrix implemented in JAX.

bcoo_broadcast_in_dim(mat, *, shape, ...)

Expand the size and rank of a BCOO array by duplicating the data.

bcoo_concatenate(operands, *, dimension)

Sparse implementation of jax.lax.concatenate().

bcoo_dot_general(lhs, rhs, *, dimension_numbers)

A general contraction operation.

bcoo_dot_general_sampled(A, B, indices, *, ...)

A contraction operation with output computed at given sparse indices.

bcoo_dynamic_slice(mat, start_indices, ...)

Sparse implementation of jax.lax.dynamic_slice().

bcoo_extract(sparr, arr, *[, assume_unique])

Extract values from a dense array according to the sparse array's indices.

bcoo_fromdense(mat, *[, nse, n_batch, ...])

Create BCOO-format sparse matrix from a dense matrix.

bcoo_gather(operand, start_indices, ...[, ...])

BCOO version of jax.lax.gather().

bcoo_multiply_dense(sp_mat, v)

An element-wise multiplication between a sparse and a dense array.

bcoo_multiply_sparse(lhs, rhs)

An element-wise multiplication of two sparse arrays.

bcoo_update_layout(mat, *[, n_batch, ...])

Update the storage layout (i.e. n_batch & n_dense) of a BCOO matrix.

bcoo_reduce_sum(mat, *, axes)

Sum array elements over the given axes.

bcoo_reshape(mat, *, new_sizes[, dimensions])

Sparse implementation of jax.lax.reshape().

bcoo_slice(mat, *, start_indices, limit_indices)

Sparse implementation of jax.lax.slice().

bcoo_sort_indices(mat)

Sort indices of a BCOO array.

bcoo_squeeze(arr, *, dimensions)

Sparse implementation of jax.lax.squeeze().

bcoo_sum_duplicates(mat[, nse])

Sums duplicate indices within a BCOO array, returning an array with sorted indices.

bcoo_todense(mat)

Convert batched sparse matrix to a dense matrix.

bcoo_transpose(mat, *, permutation)

Transpose a BCOO-format array.
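The sketch below illustrates two of the lower-level functions. A BCOO built by hand may contain duplicate coordinates. Densifying it with bcoo_todense adds them together, and bcoo_sum_duplicates merges them explicitly and returns sorted indices. The example then calls bcoo_dot_general with jax.lax-style dimension_numbers, ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)), to compute an ordinary matrix-vector product. The buffers and the vector x are illustrative data.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative buffers: two stored entries share the coordinate (0, 1).
data = jnp.array([1., 2., 3.])
indices = jnp.array([[0, 1],
                     [0, 1],
                     [2, 0]])
A = sparse.BCOO((data, indices), shape=(3, 3))

# Densifying adds duplicate entries together.
A_dense = sparse.bcoo_todense(A)

# Merge duplicates explicitly; outside jit the output nse is inferred.
A_unique = sparse.bcoo_sum_duplicates(A)

# Contract axis 1 of A with axis 0 of x; no batch dimensions.
x = jnp.array([1., 2., 3.])
Ax = sparse.bcoo_dot_general(A, x, dimension_numbers=(((1,), (0,)), ((), ())))
```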

BCSR Data Structure

BCSR is the Batched Compressed Sparse Row format, and is under development. Its operations are compatible with JAX’s core transformations, including batching (e.g. jax.vmap()) and autodiff (e.g. jax.grad()).

BCSR(args, *, shape[, indices_sorted, ...])

Experimental batched CSR matrix implemented in JAX.

bcsr_dot_general(lhs, rhs, *, dimension_numbers)

A general contraction operation.

bcsr_extract(indices, indptr, mat)

Extract values from a dense matrix at given BCSR (indices, indptr).

bcsr_fromdense(mat, *[, nse, n_batch, ...])

Create BCSR-format sparse matrix from a dense matrix.

bcsr_todense(mat)

Convert batched sparse matrix to a dense matrix.
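For comparison with the BCOO layout shown earlier, here is a small sketch of the compressed sparse row buffers that BCSR.fromdense produces for the same matrix M. CSR stores one column index per entry and replaces row indices with row offsets: the entries of row i occupy data[indptr[i]:indptr[i+1]].

```python
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
M_csr = sparse.BCSR.fromdense(M)

print(M_csr.data)     # stored values, row by row
print(M_csr.indices)  # column index of each stored value
print(M_csr.indptr)   # offsets where each row's entries begin and end

# Recover row 0 from the compressed buffers.
start, stop = int(M_csr.indptr[0]), int(M_csr.indptr[1])
row0_cols = M_csr.indices[start:stop]
row0_vals = M_csr.data[start:stop]
```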

Other Sparse Data Structures

Other sparse data structures include COO, CSR, and CSC. These are reference implementations of simple sparse structures with a few core operations implemented. Their operations are generally compatible with autodiff transformations such as jax.grad(), but not with batching transforms like jax.vmap().

COO(args, *, shape[, rows_sorted, cols_sorted])

Experimental COO matrix implemented in JAX.

CSC(args, *, shape)

Experimental CSC matrix implemented in JAX; API subject to change.

CSR(args, *, shape)

Experimental CSR matrix implemented in JAX.

coo_fromdense(mat, *[, nse, index_dtype])

Create a COO-format sparse matrix from a dense matrix.

coo_matmat(mat, B, *[, transpose])

Product of COO sparse matrix and a dense matrix.

coo_matvec(mat, v[, transpose])

Product of COO sparse matrix and a dense vector.

coo_todense(mat)

Convert a COO-format sparse matrix to a dense matrix.

csr_fromdense(mat, *[, nse, index_dtype])

Create a CSR-format sparse matrix from a dense matrix.

csr_matmat(mat, B, *[, transpose])

Product of CSR sparse matrix and a dense matrix.

csr_matvec(mat, v[, transpose])

Product of CSR sparse matrix and a dense vector.

csr_todense(mat)

Convert a CSR-format sparse matrix to a dense matrix.
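A minimal sketch of the COO and CSR reference formats, using coo_matvec and csr_matvec on the matrix M from earlier with an illustrative vector v. It also shows jax.grad differentiating through csr_matvec with respect to the dense vector, consistent with the autodiff support described above. The gradient of sum(M v) with respect to v is the vector of column sums of M.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.],
               [0., 0., 4., 0.]])
v = jnp.array([1., 2., 3., 4.])  # illustrative dense vector

M_coo = sparse.COO.fromdense(M)
M_csr = sparse.CSR.fromdense(M)

mv_coo = sparse.coo_matvec(M_coo, v)  # sparse-dense matvec, COO format
mv_csr = sparse.csr_matvec(M_csr, v)  # same product, CSR format

# Autodiff through the CSR product with respect to the dense vector.
g = jax.grad(lambda v: sparse.csr_matvec(M_csr, v).sum())(v)
```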

jax.experimental.sparse.linalg

Sparse linear algebra routines.

spsolve(data, indices, indptr, b[, tol, reorder])

A sparse direct solver using QR factorization.

lobpcg_standard(A, X[, m, tol])

Compute the top-k standard eigenvalues using the LOBPCG routine.
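spsolve takes the raw CSR buffers rather than a sparse object. The sketch below extracts them from sparse.CSR.fromdense for a small illustrative system A x = b and checks the residual. It assumes that on CPU the solver falls back to SciPy (scipy.sparse.linalg.spsolve), so SciPy must be installed there. The GPU path uses a QR-based sparse solver.

```python
import jax.numpy as jnp
from jax.experimental import sparse
from jax.experimental.sparse.linalg import spsolve

# Illustrative nonsingular system.
A = jnp.array([[4., 1., 0.],
               [1., 3., 0.],
               [0., 0., 2.]])
b = jnp.array([1., 2., 3.])

# spsolve expects the CSR arrays (data, indices, indptr) directly.
A_csr = sparse.CSR.fromdense(A)
x = spsolve(A_csr.data, A_csr.indices, A_csr.indptr, b)
```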