jax.jacfwd

jax.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)

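Jacobian of fun evaluated column-by-column using forward-mode automatic differentiation (AD). Each column is one Jacobian-vector product (JVP) with a standard basis vector, so the cost grows with the number of inputs and forward mode tends to suit "tall" Jacobians (more outputs than inputs).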
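The column-by-column idea can be made concrete with a minimal sketch, assuming a 1-D input and a 1-D output. It is not jax.jacfwd's actual implementation; jacfwd_sketch is a hypothetical helper that pushes each standard basis vector through jax.jvp, batches those pushes with jax.vmap, and checks the result against jax.jacfwd.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Same R^3 -> R^4 function as the example on this page.
    return jnp.asarray(
        [x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])

def jacfwd_sketch(fun, x):
    # Hypothetical helper; assumes x and fun(x) are both 1-D arrays.
    basis = jnp.eye(x.size, dtype=x.dtype)  # one standard basis vector per row

    def column(v):
        # A JVP with basis vector e_i yields the i-th Jacobian column.
        _, tangent_out = jax.jvp(fun, (x,), (v,))
        return tangent_out

    cols = jax.vmap(column)(basis)  # shape (n_inputs, n_outputs)
    return cols.T                   # shape (n_outputs, n_inputs)

x = jnp.array([1.0, 2.0, 3.0])
J_sketch = jacfwd_sketch(f, x)
J_ref = jax.jacfwd(f)(x)
print(jnp.allclose(J_sketch, J_ref))
```

The sketch needs one JVP per input coordinate, which is why forward mode is cheapest when inputs are few.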

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed.

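  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0). If a sequence is given, the Jacobians with respect to the selected arguments are returned as a tuple.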

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
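The argnums and holomorphic options can be sketched briefly as follows. The functions g and h are hypothetical examples chosen so the expected Jacobians are easy to derive by hand: for g, dg/dx = diag(y + cos x) and dg/dy = diag(x); for h(z) = z², the complex derivative is 2z.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # Hypothetical elementwise function of two array arguments.
    return x * y + jnp.sin(x)

x = jnp.array([0.5, 1.0])
y = jnp.array([2.0, 3.0])

# argnums as a sequence: one Jacobian per selected argument, as a tuple.
J_x, J_y = jax.jacfwd(g, argnums=(0, 1))(x, y)
print(J_x)  # diag(y + cos(x))
print(J_y)  # diag(x)

def h(z):
    # Hypothetical holomorphic function of a complex input.
    return z ** 2

# holomorphic=True expects complex inputs and returns the complex derivative.
z = jnp.array(1.0 + 1.0j)
dh = jax.jacfwd(h, holomorphic=True)(z)
print(dh)  # 2 * z
```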

Returns:

A function with the same arguments as fun that evaluates the Jacobian of fun using forward-mode automatic differentiation. If has_aux is True, the function returns a pair (jacobian, auxiliary_data).
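As an illustration of the has_aux return convention, here is a small sketch with a hypothetical function f_with_aux. Its first output is differentiated; its second output (here an assumed squared-norm diagnostic) is returned alongside the Jacobian without being differentiated. For an elementwise tanh, the Jacobian is diag(1 - tanh(x)²).

```python
import jax
import jax.numpy as jnp

def f_with_aux(x):
    y = jnp.tanh(x)
    # Pair: (output to differentiate, auxiliary data passed through as-is).
    return y, jnp.sum(y ** 2)

x = jnp.array([0.1, 0.2, 0.3])
jac, aux = jax.jacfwd(f_with_aux, has_aux=True)(x)
print(jac.shape)  # (3, 3)
print(aux)        # value of sum(tanh(x)**2) at x
```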

Return type:

Callable

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
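The printed values can be cross-checked against the hand-derived Jacobian. Row by row, the partial derivatives with respect to (x0, x1, x2) are (1, 0, 0), (0, 0, 5), (0, 8·x1, -2) and (x2·cos x0, 0, sin x0), so at x = (1, 2, 3) the last row is (3·cos 1, 0, sin 1) ≈ (1.6209, 0, 0.84147). A small sketch of that check:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.asarray(
        [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])

x = jnp.array([1., 2., 3.])

# Hand-derived partials of each output row with respect to x0, x1, x2.
expected = jnp.array([
    [1.0, 0.0, 0.0],                              # d(x0)
    [0.0, 0.0, 5.0],                              # d(5*x2)
    [0.0, 8.0 * x[1], -2.0],                      # d(4*x1**2 - 2*x2)
    [x[2] * jnp.cos(x[0]), 0.0, jnp.sin(x[0])],   # d(x2*sin(x0))
])

J = jax.jacfwd(f)(x)
print(jnp.allclose(J, expected))
```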