jax.vjp(fun: Callable[..., jax._src.api.T], *primals: Any, has_aux: Literal[False] = False, reduce_axes: Sequence[Any] = ()) → Tuple[jax._src.api.T, Callable]
jax.vjp(fun: Callable[..., Tuple[jax._src.api.T, jax._src.api.U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[Any] = ()) → Tuple[jax._src.api.T, Callable, jax._src.api.U]

Compute a (reverse-mode) vector-Jacobian product of fun.

grad() is implemented as a special case of vjp().
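The relationship can be illustrated with a short sketch; the function f and the input x are hypothetical examples, not part of the API. For a scalar-valued function, pulling the cotangent 1.0 back through the VJP gives the gradient, so both results should agree. This demonstrates the mathematical equivalence only and is not the actual source of grad().

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar-valued function of a vector input.
    return jnp.sum(jnp.sin(x) * x)

x = jnp.array([0.1, 0.2, 0.3])

# Gradient computed directly with jax.grad.
g_grad = jax.grad(f)(x)

# Same gradient via jax.vjp: for a scalar output, pull back a cotangent of 1.0.
y, f_vjp = jax.vjp(f, x)
(g_vjp,) = f_vjp(jnp.ones_like(y))  # one cotangent per primal; here just x

print(jnp.allclose(g_grad, g_vjp))  # True
```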

Parameters

  • fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

  • primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The number of primals should equal the number of positional parameters of fun. Each primal value should be an array, a scalar, or a standard Python container (pytree) thereof.

  • has_aux (bool) – Optional. Indicates whether fun returns a pair in which the first element is the output of the mathematical function to be differentiated and the second element is auxiliary data. Defaults to False.

  • reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch' is a named batch axis, vjp(f, *args, reduce_axes=('batch',)) will create a VJP function that sums over the batch while vjp(f, *args) will create a per-example VJP.
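As an illustration of pytree primals and has_aux, the sketch below uses a hypothetical loss function that takes a dict of parameters plus an array and returns its predictions as auxiliary data. With has_aux=True, vjp returns three values; the VJP function yields one cotangent per positional primal, each with the same structure as that primal, while aux is passed through without being differentiated. reduce_axes is not exercised here.

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # Hypothetical function: params is a dict (a pytree of arrays); x is a plain array.
    pred = params["w"] * x + params["b"]
    value = jnp.sum(pred ** 2)
    # With has_aux=True, fun returns (output, aux); aux is not differentiated.
    return value, {"pred": pred}

params = {"w": jnp.array(2.0), "b": jnp.array(0.5)}
x = jnp.array([1.0, 2.0, 3.0])

# Three return values because has_aux=True.
value, loss_vjp, aux = jax.vjp(loss, params, x, has_aux=True)

# One cotangent per positional primal: a dict shaped like params, and an array shaped like x.
params_bar, x_bar = loss_vjp(jnp.ones_like(value))

print(params_bar["w"], params_bar["b"])  # d(value)/dw and d(value)/db
print(x_bar.shape)                       # (3,)
print(aux["pred"])
```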

Return type

Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]

Returns

If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). If has_aux is True, returns a (primals_out, vjpfun, aux) tuple, where aux is the auxiliary data returned by fun. In both cases, vjpfun is a function that maps a cotangent vector with the same shape (and structure) as primals_out to a tuple of cotangent vectors, one per primal, each with the same shape as the corresponding primal. This result is the vector-Jacobian product of fun evaluated at primals.
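To make "vector-Jacobian product" concrete, this sketch uses a hypothetical function f from R^3 to R^2 and compares the VJP with an explicit Jacobian built by jax.jacobian: pulling back a cotangent v should give v @ J.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function mapping a length-3 vector to a length-2 vector.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0])   # cotangent: same shape as f(x)

y, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(v)          # tuple with one cotangent per primal; x_bar has shape (3,)

J = jax.jacobian(f)(x)       # explicit Jacobian, shape (2, 3)
print(jnp.allclose(x_bar, v @ J))  # True
```

Materializing J is only for comparison; the point of vjp is that it computes v @ J without forming the Jacobian.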

>>> import jax
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
>>> primals_out, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
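Because vjpfun is itself a function that JAX can transform, the forward pass can be done once and vjpfun reused for many cotangents. A sketch, assuming a hypothetical function f from R^3 to R^3: mapping vjpfun over the rows of the identity matrix with jax.vmap recovers the full Jacobian one row at a time, which should match jax.jacrev. This mirrors the idea behind reverse-mode Jacobians but is not presented as jax.jacrev's actual source.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function: elementwise sine followed by a cumulative sum.
    return jnp.cumsum(jnp.sin(x))

x = jnp.array([0.5, 1.0, 1.5])
y, f_vjp = jax.vjp(f, x)  # forward pass runs once here

# Each row e_i of the identity is a cotangent; its pullback e_i^T J is row i of J.
basis = jnp.eye(y.shape[0], dtype=y.dtype)
(J_rows,) = jax.vmap(f_vjp)(basis)

print(J_rows.shape)                            # (3, 3)
print(jnp.allclose(J_rows, jax.jacrev(f)(x)))  # True
```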