# jax.jacrev

jax.jacrev(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)

Jacobian of `fun` evaluated row-by-row using reverse-mode AD.
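Reverse mode builds the Jacobian one row at a time. Each row is a vector-Jacobian product (VJP) of `fun` with a one-hot cotangent. The sketch below reproduces that idea with `jax.vjp` and `jax.vmap` for a function mapping a 1-D array to a 1-D array, and checks it against `jax.jacrev`. `jacrev_by_rows` is a hypothetical helper written for illustration. It is not part of JAX's API and is not claimed to mirror JAX's internal implementation line for line.

```python
import jax
import jax.numpy as jnp


def f(x):
    # The same R^3 -> R^4 function used in the example on this page.
    return jnp.asarray([x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])


def jacrev_by_rows(fun, x):
    """Hypothetical helper: assemble the Jacobian row by row from VJPs.

    Assumes `fun` maps a 1-D array to a 1-D array.
    """
    y, pullback = jax.vjp(fun, x)
    # Row i of the Jacobian is e_i^T J, the pullback of the one-hot cotangent e_i.
    basis = jnp.eye(y.shape[0], dtype=y.dtype)
    # pullback returns a tuple (one cotangent per primal input); take the first.
    return jax.vmap(lambda ct: pullback(ct)[0])(basis)


x = jnp.array([1.0, 2.0, 3.0])
J_manual = jacrev_by_rows(f, x)
J_jax = jax.jacrev(f)(x)
print(jnp.allclose(J_manual, J_jax))  # True
```

Because the number of VJPs equals the number of outputs, this row-by-row construction is cheapest when `fun` has fewer outputs than inputs.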

Parameters:
• fun (Callable) – Function whose Jacobian is to be computed.

• argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default `0`).

• has_aux (bool) – Optional, bool. Indicates whether `fun` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

• holomorphic (bool) – Optional, bool. Indicates whether `fun` is promised to be holomorphic. Default False.

• allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
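A minimal sketch of `argnums` and `has_aux` used together. It assumes a hypothetical two-argument function `g` that returns its output alongside a dict of auxiliary data. When `argnums` is a tuple, the Jacobian part of the result is a tuple with one entry per selected argument. The auxiliary data is returned as-is and is not differentiated.

```python
import jax
import jax.numpy as jnp


def g(x, y):
    # Hypothetical function: elementwise product plus the sum of y.
    out = x * y + jnp.sum(y)
    aux = {"norm": jnp.linalg.norm(out)}  # auxiliary data, not differentiated
    return out, aux


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

# argnums=(0, 1): one Jacobian per selected argument, returned as a tuple.
# has_aux=True: the result is a pair (jacobians, aux).
(jac_x, jac_y), aux = jax.jacrev(g, argnums=(0, 1), has_aux=True)(x, y)

print(jac_x)        # d out / d x, which is diag(y)
print(jac_y)        # d out / d y, which is diag(x) plus a matrix of ones
print(aux["norm"])  # norm of out = [10., 15.]
```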

Returns:

A function with the same arguments as `fun` that evaluates the Jacobian of `fun` using reverse-mode automatic differentiation. If `has_aux` is True, the function returns a pair `(jacobian, auxiliary_data)`.

Return type:

Callable
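For array inputs and outputs, the returned Jacobian has the output shape followed by the input shape. The following short check uses a hypothetical function `h` that maps a (2, 3) matrix to a length-4 vector:

```python
import jax
import jax.numpy as jnp


def h(w):
    # Hypothetical function: (2, 3) matrix -> (4,) vector.
    v = jnp.arange(3.0)
    return jnp.concatenate([w @ v, jnp.sum(w, axis=1)])


w = jnp.ones((2, 3))
jac = jax.jacrev(h)(w)

# Jacobian shape = output shape + input shape.
print(h(w).shape, w.shape, jac.shape)  # (4,) (2, 3) (4, 2, 3)
```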

```python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
```
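The `holomorphic` flag applies to complex-valued functions. The following minimal sketch assumes the elementwise map `z ** 2`, whose complex derivative is `2 * z`. With `holomorphic=True` and complex inputs, the Jacobian of this elementwise map should be diagonal with entries `2 * z`. The function name `sq` is hypothetical.

```python
import jax
import jax.numpy as jnp


def sq(z):
    # Elementwise z**2 is holomorphic, with derivative 2 * z.
    return z ** 2


z = jnp.array([1.0 + 1.0j, 2.0 - 0.5j])  # holomorphic=True expects complex inputs
jac = jax.jacrev(sq, holomorphic=True)(z)

# An elementwise map has a diagonal Jacobian; here it should equal diag(2 * z).
print(jnp.allclose(jac, jnp.diag(2 * z)))  # True
```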