jax.jacrev(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)[source]#

Jacobian of fun evaluated row-by-row using reverse-mode AD.

  • fun (Callable) – Function whose Jacobian is to be computed.

  • argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

Return type



A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.

>>> import jax
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]