jax.jacfwd
- jax.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)
Jacobian of fun evaluated column-by-column using forward-mode AD.
- Parameters
fun (Callable) – Function whose Jacobian is to be computed.
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Return type
- Returns
A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.
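The argnums and has_aux options described above can be sketched together as follows. The function g, its inputs, and the "norm_x" auxiliary entry are hypothetical and chosen only for illustration. The sketch assumes that passing a tuple to argnums yields one Jacobian per selected argument.

```python
import jax
import jax.numpy as jnp

# Hypothetical function for illustration: returns (output, auxiliary_data).
def g(x, y):
    out = jnp.stack([x[0] * y[0], x[1] + y[1] ** 2])
    aux = {"norm_x": jnp.linalg.norm(x)}  # auxiliary data, not differentiated
    return out, aux

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

# argnums=(0, 1): differentiate with respect to both positional arguments.
# has_aux=True: the result is a pair (jacobians, auxiliary_data).
(jac_x, jac_y), aux = jax.jacfwd(g, argnums=(0, 1), has_aux=True)(x, y)

print(jac_x)  # Jacobian of the output with respect to x, shape (2, 2)
print(jac_y)  # Jacobian of the output with respect to y, shape (2, 2)
print(aux["norm_x"])
```

Here jac_x and jac_y are each 2-by-2 arrays, and aux is passed through unchanged from g.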
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
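To illustrate what "column-by-column using forward-mode AD" means, the sketch below rebuilds the same Jacobian by pushing each standard basis vector through jax.jvp and stacking the results as columns. The names jvp_column, basis, and manual_jac are hypothetical. This shows the idea only and is not necessarily how jacfwd is implemented internally.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.asarray(
        [x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])

x = jnp.array([1., 2., 3.])

def jvp_column(tangent):
    # jax.jvp returns (primal_out, tangent_out). For a standard basis
    # vector e_i as the tangent, tangent_out is column i of the Jacobian.
    _, col = jax.jvp(f, (x,), (tangent,))
    return col

# One basis vector per input dimension (rows of the identity matrix).
basis = jnp.eye(x.shape[0])

# Map over the basis vectors and place each result along the last axis,
# so that the results become the columns of a (4, 3) matrix.
manual_jac = jax.vmap(jvp_column, out_axes=-1)(basis)

print(manual_jac)
```

Each forward-mode pass yields one column, so the number of passes scales with the number of inputs rather than the number of outputs.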