jax.experimental.sparse.value_and_grad
jax.experimental.sparse.value_and_grad(fun, argnums=0, has_aux=False, **kwargs)
Sparse-aware version of jax.value_and_grad().

Arguments and return values are the same as jax.value_and_grad(), but when taking the gradient with respect to a jax.experimental.sparse array, the gradient is computed in the subspace defined by the array’s sparsity pattern.

Examples
>>> import jax.numpy as jnp
>>> from jax.experimental import sparse
>>> X = sparse.BCOO.fromdense(jnp.arange(6.))
>>> y = jnp.ones(6)
>>> sparse.value_and_grad(lambda X, y: X @ y)(X, y)
(Array(15., dtype=float32), BCOO(float32[6], nse=5))
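To illustrate what "computed in the subspace defined by the array's sparsity pattern" means in practice, the following sketch compares the sparse-aware gradient with the ordinary dense gradient of the same function. The helper name `f` and the variables `dense_grad` and `sparse_grad` are illustrative choices, not part of the API; the comparison assumes the same inputs as the example above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Same inputs as the example above: X stores 5 nonzero entries (indices 1..5).
X = sparse.BCOO.fromdense(jnp.arange(6.))
y = jnp.ones(6)

def f(X, y):
    # Inner product; works for both a BCOO and a dense X.
    return X @ y

# Sparse-aware gradient: returned as a BCOO that reuses X's stored indices.
value, sparse_grad = sparse.value_and_grad(f)(X, y)

# Ordinary gradient with respect to the densified array, for comparison.
dense_value, dense_grad = jax.value_and_grad(f)(X.todense(), y)

print(sparse_grad.indices.ravel())  # positions that carry a gradient entry
print(sparse_grad.todense())        # zero wherever X has no stored element
print(dense_grad)                   # dense gradient covers every position
```

In this sketch the dense gradient is nonzero at index 0, while the sparse-aware gradient has no entry there because X stores no element at that position. The two results should agree on the stored positions only.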