jax.jacrev
- jax.jacrev(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Jacobian of fun evaluated row-by-row using reverse-mode AD.
- Parameters:
  - fun (Callable) – Function whose Jacobian is to be computed.
  - argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  - has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  - holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
  - allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
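A minimal sketch of the optional parameters, assuming a toy function `g` of two array arguments (the function and its auxiliary dictionary are purely illustrative). With a tuple `argnums`, the result holds one Jacobian per selected argument; with `has_aux=True`, the auxiliary data is passed through next to it. The last lines show `holomorphic=True` on a complex-valued input.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # Illustrative function of two shape-(2,) arguments with a shape-(2,) output.
    out = jnp.array([x[0] * y[0], x[1] + y[1] ** 2])
    aux = {"total": out.sum()}  # auxiliary data: returned, not differentiated
    return out, aux

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

# argnums=(0, 1): a tuple of Jacobians (d out / dx, d out / dy).
# has_aux=True:   jacrev returns (jacobian, aux) instead of just the jacobian.
(jac_x, jac_y), aux = jax.jacrev(g, argnums=(0, 1), has_aux=True)(x, y)
# Analytically: jac_x = [[y0, 0], [0, 1]], jac_y = [[x0, 0], [0, 2*y1]].

# holomorphic=True: inputs and outputs must be complex; the result is the
# complex derivative, here f'(z) = 2z on the diagonal for f(z) = z**2.
z = jnp.array([1.0 + 1.0j, 2.0 + 0.0j])
jac_z = jax.jacrev(lambda z: z ** 2, holomorphic=True)(z)
```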
- Returns:
  A function with the same arguments as fun that evaluates the Jacobian of fun using reverse-mode automatic differentiation. If has_aux is True, then a pair of (jacobian, auxiliary_data) is returned.
- Return type:
  Callable
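The phrase "row-by-row" can be made concrete with a short sketch. Assuming fun maps a 1-D array to a 1-D array, row i of the Jacobian is the vector-Jacobian product with the i-th standard-basis cotangent, and `jax.vmap` evaluates all rows in one batched call. The helper `jacrev_sketch` is hypothetical; the real `jax.jacrev` additionally handles pytrees, `argnums`, `has_aux`, and dtype checks.

```python
import jax
import jax.numpy as jnp

def jacrev_sketch(f):
    """Hypothetical helper: Jacobian of a 1-D -> 1-D function via batched VJPs."""
    def jac_fn(x):
        # One forward pass; vjp_fn maps an output cotangent to an input cotangent.
        y, vjp_fn = jax.vjp(f, x)
        # Row i of the Jacobian is e_i^T J: the VJP with the i-th basis vector.
        basis = jnp.eye(y.size, dtype=y.dtype)
        # vjp_fn returns a tuple with one entry per primal argument; take x's.
        return jax.vmap(lambda ct: vjp_fn(ct)[0])(basis)
    return jac_fn

def f(x):
    return jnp.asarray([x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])

x = jnp.array([1.0, 2.0, 3.0])
J_sketch = jacrev_sketch(f)(x)  # shape (4, 3): one row per output element
J_ref = jax.jacrev(f)(x)
```

The number of VJPs equals the number of outputs, which is why reverse mode is usually the cheaper choice when a function has many inputs and few outputs.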
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
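As a hedged cross-check on the example above, the same matrix can be computed with forward mode via `jax.jacfwd`, which builds the Jacobian column-by-column instead of row-by-row. For a scalar-valued function, `jax.jacrev` returns an array shaped like the input that coincides with `jax.grad`; the function `h` below is an illustrative assumption.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Same function as in the example above: R^3 -> R^4.
    return jnp.asarray([x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])

x = jnp.array([1.0, 2.0, 3.0])

J_rev = jax.jacrev(f)(x)  # reverse mode: one VJP per output (4 rows)
J_fwd = jax.jacfwd(f)(x)  # forward mode: one JVP per input (3 columns)

def h(x):
    # Illustrative scalar-valued function; its gradient is 2 * x.
    return jnp.sum(x ** 2)

g_rev = jax.jacrev(h)(x)  # single Jacobian row, reshaped to x.shape
g_ref = jax.grad(h)(x)
```

Both modes give the same Jacobian; the choice is about cost. Here f has more outputs than inputs, so `jax.jacfwd` needs fewer passes, while for scalar losses reverse mode needs only one.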