jax.experimental.sparse.grad


jax.experimental.sparse.grad(fun, argnums=0, has_aux=False, **kwargs)

Sparse-aware version of jax.grad().

Arguments and return values are the same as jax.grad(), but when taking the gradient with respect to a jax.experimental.sparse array, the gradient is computed in the subspace defined by the array’s sparsity pattern.

Example

>>> import jax.numpy as jnp
>>> from jax.experimental import sparse
>>> X = sparse.BCOO.fromdense(jnp.arange(6.))
>>> y = jnp.ones(6)
>>> sparse.grad(lambda X, y: X @ y)(X, y)
BCOO(float32[6], nse=5)
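
To make "the subspace defined by the array's sparsity pattern" concrete, the following sketch compares jax.grad() on a dense array with sparse.grad() on its BCOO counterpart. The names f, X_dense, X_sparse, dense_grad and sparse_grad are illustrative and not part of the API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def f(X, y):
    # Scalar-valued function: the dot product of X and y.
    return X @ y

X_dense = jnp.arange(6.)                    # [0., 1., 2., 3., 4., 5.]
X_sparse = sparse.BCOO.fromdense(X_dense)   # stores the 5 nonzero entries
y = jnp.ones(6)

# Dense gradient: one entry per element of X, including the zero at index 0.
dense_grad = jax.grad(f)(X_dense, y)

# Sparse gradient: a BCOO array with entries only at the stored positions of X.
sparse_grad = sparse.grad(f)(X_sparse, y)

print(dense_grad)              # dense array of shape (6,)
print(sparse_grad)             # BCOO array with nse=5
print(sparse_grad.todense())   # zero wherever X has no stored entry
```

The dense gradient has a nonzero entry at index 0, while the sparse gradient has no stored entry there, because index 0 lies outside the sparsity pattern of X_sparse.
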
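Parameters:
  • fun (Callable) – the function to differentiate; same meaning as in jax.grad().
  • argnums (int | Sequence[int]) – the positional argument or arguments to differentiate with respect to; same meaning as in jax.grad().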

Return type:

Callable
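
As a rough illustration of the argnums parameter, the sketch below differentiates with respect to the dense argument and then with respect to both arguments at once. It assumes that argnums behaves as it does in jax.grad(), as stated above; the names f, g_y, g_X and g_y2 are illustrative only.

```python
import jax.numpy as jnp
from jax.experimental import sparse

def f(X, y):
    # Scalar-valued function of a sparse X and a dense y.
    return X @ y

X = sparse.BCOO.fromdense(jnp.arange(6.))
y = jnp.ones(6)

# argnums=1: gradient with respect to the dense argument y (a dense array).
g_y = sparse.grad(f, argnums=1)(X, y)

# argnums=(0, 1): a tuple of gradients, one per selected argument.
# The gradient for X is a BCOO array; the gradient for y is dense.
g_X, g_y2 = sparse.grad(f, argnums=(0, 1))(X, y)

print(g_y)
print(g_X)
```

Only gradients taken with respect to sparse arguments are restricted to a sparsity pattern; the gradient with respect to y is an ordinary dense array.
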