jax.value_and_grad

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

Create a function that evaluates both fun and the gradient of fun.
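A minimal sketch of the basic call pattern. The sum-of-squares function f is a hypothetical example chosen for illustration. The function returned by value_and_grad is called with the same arguments as f and gives back the value and the gradient together, so there is no need to call f and jax.grad(f) separately.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar-valued function: sum of squares.
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
value, dfdx = jax.value_and_grad(f)(x)
# value is f(x) = 1 + 4 + 9; dfdx is the gradient 2 * x.
```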

Parameters
  • fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,), etc.).

  • argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

  • reduce_axes (Sequence[Any]) – Optional, tuple of axis names. If an axis is listed here and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, value_and_grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient, while value_and_grad(f) will create one that computes the per-example gradient.
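The requirements on fun can be sketched as follows, assuming a hypothetical loss whose parameters live in a dict (a standard Python container). The gradient comes back with the same container structure as the differentiated argument. A function that returns shape (1,) instead of shape () is rejected; in the JAX versions we are aware of this surfaces as a TypeError.

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # Hypothetical loss: params is a dict; the output is a 0-d (shape ()) scalar.
    return jnp.sum((params["w"] * x + params["b"]) ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
x = jnp.array([1.0, 2.0])
value, grads = jax.value_and_grad(loss)(params, x)
# grads is a dict with the same keys and shapes as params.

def bad_loss(params, x):
    # keepdims=True makes the output shape (1,), which is not treated as a scalar.
    return jnp.sum((params["w"] * x + params["b"]) ** 2, keepdims=True)

try:
    jax.value_and_grad(bad_loss)(params, x)
    raised = False
except TypeError:
    raised = True
```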
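For holomorphic=True, a minimal sketch with the hypothetical function f(z) = z**2, which is holomorphic with complex derivative 2z. Both the input and the output are complex, as the parameter description requires, and the returned gradient is the complex derivative at z.

```python
import jax
import jax.numpy as jnp

def f(z):
    # Hypothetical holomorphic function; its complex derivative is 2 * z.
    return z ** 2

z = jnp.array(1.0 + 2.0j)  # complex input, required when holomorphic=True
value, dfdz = jax.value_and_grad(f, holomorphic=True)(z)
# value is (1 + 2j)**2; dfdz is 2 * (1 + 2j).
```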
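For allow_int, a sketch assuming a hypothetical function with one float argument and one integer argument. With allow_int=True the integer argument may be listed in argnums, and its gradient has the trivial dtype float0 (exposed as jax.dtypes.float0). Without the flag, JAX rejects the integer input with a TypeError.

```python
import jax
import jax.numpy as jnp

def f(x, n):
    # Hypothetical function: x is float, n is an integer used as a multiplier.
    return jnp.sum(x) * n

x = jnp.array([1.0, 2.0])
n = jnp.array(3)  # integer-valued input

value, (gx, gn) = jax.value_and_grad(f, argnums=(0, 1), allow_int=True)(x, n)
# gx is an ordinary float gradient; gn has dtype float0.

try:
    jax.value_and_grad(f, argnums=(0, 1))(x, n)  # allow_int left at False
    rejected = False
except TypeError:
    rejected = True
```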
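The per-example versus total distinction described for reduce_axes can be sketched without passing reduce_axes itself, since support for that parameter may vary across JAX versions. This is a swapped-in technique: a named jax.vmap axis 'batch' plus an explicit jax.lax.psum, which is the reduction the parameter describes. The loss, the shared weights w, and the batch xs are hypothetical. Here w is broadcast over the batch (in_axes=None), so its gradient is either kept per example or summed over 'batch'.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical per-example loss; its gradient w.r.t. w is 2 * (w . x) * x.
    return jnp.sum(w * x) ** 2

def per_example(w, x):
    # Gradient for a single example only.
    return jax.value_and_grad(loss)(w, x)

def total(w, x):
    # Explicitly sum the gradient over the named batch axis.
    val, g = jax.value_and_grad(loss)(w, x)
    return val, jax.lax.psum(g, axis_name="batch")

w = jnp.array([1.0, 2.0])  # shared across the batch
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

vals, per_ex_grads = jax.vmap(per_example, in_axes=(None, 0), axis_name="batch")(w, xs)
_, total_grads = jax.vmap(total, in_axes=(None, 0), axis_name="batch")(w, xs)
# per_ex_grads has one row per example; every row of total_grads is their sum.
```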

Return type

Callable[…, Tuple[Any, Any]]

Returns

A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
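The three return shapes described above can be sketched with a hypothetical dot-product function f and a variant f_aux that also returns auxiliary data. An integer argnums gives a single gradient, a sequence gives a tuple, and has_aux=True nests the auxiliary data with the value.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Hypothetical scalar output: the dot product of x and y.
    return jnp.sum(x * y)

def f_aux(x, y):
    # Same value, plus auxiliary data that is returned but not differentiated.
    return jnp.sum(x * y), {"prod": x * y}

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

# argnums is an integer: the gradient has the shape and dtype of y.
value, gy = jax.value_and_grad(f, argnums=1)(x, y)

# argnums is a sequence: the gradient is a tuple, one entry per listed argument.
value2, (gx, gy2) = jax.value_and_grad(f, argnums=(0, 1))(x, y)

# has_aux=True: the result is ((value, auxiliary_data), gradient).
(value3, aux), gx3 = jax.value_and_grad(f_aux, has_aux=True)(x, y)
```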