jax.experimental.sparse.sparsify
jax.experimental.sparse.sparsify(f, use_tracer=False)
Experimental sparsification transform.
Examples
Decorate JAX functions to make them compatible with
jax.experimental.sparse.BCOO matrices:

>>> import jax.numpy as jnp
>>> from jax.experimental import sparse

>>> @sparse.sparsify
... def f(M, v):
...   return 2 * M.T @ v

>>> M = sparse.BCOO.fromdense(jnp.arange(12).reshape(3, 4))
>>> v = jnp.array([3, 4, 2])
>>> f(M, v)
Array([ 64,  82, 100, 118], dtype=int32)
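As a quick sanity check, the sketch below reuses the decorated function `f` from the example. It compares the result on a BCOO matrix with the same expression written in plain `jax.numpy` on the dense matrix. It also calls the sparsified function with an ordinary dense array, on the assumption that dense inputs are accepted and simply computed densely. The names `dense_M`, `out_reference`, and `out_dense_input` are illustrative only.

```python
import jax.numpy as jnp
from jax.experimental import sparse


@sparse.sparsify
def f(M, v):
    # Same computation as the example: (2 * M.T) @ v.
    return 2 * M.T @ v


dense_M = jnp.arange(12).reshape(3, 4)  # plain dense 3x4 matrix
M = sparse.BCOO.fromdense(dense_M)      # BCOO copy of the same matrix
v = jnp.array([3, 4, 2])

# Sparse input: sparsify evaluates f using sparse-aware operations on M.
out_sparse = f(M, v)

# Reference result computed with ordinary dense ops and no sparsify.
out_reference = 2 * dense_M.T @ v

# Assumption: the sparsified function also accepts a dense array for M.
out_dense_input = f(dense_M, v)
```

All three results should hold the same values, which is a cheap way to check that a sparsified function matches its dense counterpart before relying on it.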