jax.experimental.sparse.value_and_grad

jax.experimental.sparse.value_and_grad(fun, argnums=0, has_aux=False, **kwargs)

Sparse-aware version of jax.value_and_grad().

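Arguments and return values are the same as for jax.value_and_grad(), with one difference: when the gradient is taken with respect to a jax.experimental.sparse array, it is computed only in the subspace defined by that array’s sparsity pattern. The gradient is returned as a sparse array with the same indices as the input, and entries that are not stored in the input receive no gradient.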

Example

>>> import jax.numpy as jnp
>>> from jax.experimental import sparse
>>> X = sparse.BCOO.fromdense(jnp.arange(6.))
>>> y = jnp.ones(6)
>>> sparse.value_and_grad(lambda X, y: X @ y)(X, y)
(Array(15., dtype=float32), BCOO(float32[6], nse=5))
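To make the "subspace defined by the sparsity pattern" concrete, the following sketch contrasts the sparse-aware gradient with the ordinary dense one for the same function. The variable names (x_dense, grad_sparse, grad_dense, and so on) are illustrative and not part of the API; the comparison assumes the same inputs as the example above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Dense input whose first entry is zero, and its BCOO representation.
x_dense = jnp.arange(6.)
X = sparse.BCOO.fromdense(x_dense)  # only the 5 nonzero entries are stored
y = jnp.ones(6)

f = lambda X, y: X @ y

# Sparse-aware: differentiates only the stored entries of X.
value_sparse, grad_sparse = sparse.value_and_grad(f)(X, y)

# Dense baseline: differentiates every entry, including the zero at index 0.
value_dense, grad_dense = jax.value_and_grad(f)(x_dense, y)

print(grad_sparse.todense())  # zero at index 0, one elsewhere
print(grad_dense)             # one at every index
```

Both calls return the same value, but the sparse gradient carries no entry for index 0, because that position lies outside the sparsity pattern of X.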
Parameters:

  • fun (Callable)
  • argnums (int | Sequence[int])
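As a minimal sketch of how argnums can be used, the snippet below differentiates with respect to both a sparse and a dense argument at once. It assumes argnums behaves as in jax.value_and_grad(), so that a tuple of indices yields a tuple of gradients; the names grad_X and grad_y are illustrative.

```python
import jax.numpy as jnp
from jax.experimental import sparse

X = sparse.BCOO.fromdense(jnp.arange(6.))
y = jnp.ones(6)

f = lambda X, y: X @ y

# Request gradients for argument 0 (sparse) and argument 1 (dense).
value, (grad_X, grad_y) = sparse.value_and_grad(f, argnums=(0, 1))(X, y)

print(type(grad_X).__name__)  # the gradient for the sparse argument is itself sparse
print(grad_y)                 # the gradient for the dense argument is a dense array
```

The gradient for the sparse argument keeps the sparsity pattern of X, while the gradient for the dense argument is an ordinary array.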

Return type:

Callable[…, tuple[Any, Any]]