jax.experimental.sparse.sparsify

jax.experimental.sparse.sparsify(f, use_tracer=False)

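Experimental sparsification transform. It wraps a function f written with ordinary JAX operations so that f can be called with sparse array arguments.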

Examples

Decorate JAX functions to make them compatible with jax.experimental.sparse.BCOO matrices:

>>> import jax.numpy as jnp
>>> from jax.experimental import sparse
>>> @sparse.sparsify
... def f(M, v):
...   return 2 * M.T @ v
>>> M = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
>>> v = jnp.array([3, 4, 2])
>>> f(M, v)
Array([ 64,  82, 100, 118], dtype=int32)
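
The transform can also be applied as a plain function call rather than as a decorator, which leaves the original function available for dense inputs. The following is a minimal sketch under that assumption; the names g, g_sparse, M_dense, and M_sparse are illustrative and not part of the API. It runs the same computation on a dense array and on its BCOO counterpart so the two results can be compared.

```python
import jax.numpy as jnp
from jax.experimental import sparse


def g(M, v):
    # Ordinary dense-style JAX code; nothing here is sparse-specific.
    return 2 * M.T @ v


# Wrap g without decorator syntax; g itself stays usable on dense arrays.
g_sparse = sparse.sparsify(g)

M_dense = jnp.arange(12).reshape(3, 4)
M_sparse = sparse.BCOO.fromdense(M_dense)  # sparse copy of the same matrix
v = jnp.array([3, 4, 2])

dense_result = g(M_dense, v)          # plain dense evaluation
sparse_result = g_sparse(M_sparse, v)  # evaluation through the sparsified wrapper

# Both paths should produce the same values.
print(jnp.array_equal(dense_result, sparse_result))
```

Keeping the unwrapped function around in this way is convenient for checking a sparsified computation against its dense reference on small inputs.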