jax.value_and_grad

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

Create a function that evaluates both fun and the gradient of fun.

Parameters:
- fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.).
- argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
- allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
- reduce_axes (Sequence[AxisName])
Returns:
A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.

Return type:
Callable[…, tuple[Any, Any]]
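A minimal usage sketch (the function f and the auxiliary dictionary below are illustrative examples, not part of the API):

```python
import jax
import jax.numpy as jnp

# Basic usage: evaluate the function and its gradient in one pass.
def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
value, grad = jax.value_and_grad(f)(x)
# value is f(x); grad has the same shape and dtype as x.

# With has_aux=True, fun returns (output, aux) and the transformed
# function returns ((value, auxiliary_data), gradient).
def loss_with_aux(x):
    loss = jnp.sum(x ** 2)
    return loss, {"mean": jnp.mean(x)}  # illustrative auxiliary data

(loss, aux), g = jax.value_and_grad(loss_with_aux, has_aux=True)(x)
```

This is often cheaper than calling fun and jax.grad(fun) separately, since the value and gradient come from a single forward/backward pass.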