- jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())
Creates a function that evaluates the gradient of fun.
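- Parameters:
fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,), etc.).
argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to. Default 0.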
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
reduce_axes (Sequence[AxisName]) – Optional, tuple of axis names. If an axis is listed here and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient, while grad(f) will create one that computes the per-example gradient.
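The per-example versus total distinction described for reduce_axes can be illustrated without named axes. The sketch below deliberately does not use reduce_axes (which may not be available in every JAX release); it swaps in jax.vmap over a batch dimension plus an explicit sum. The function loss and the arrays w and xs are hypothetical.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical per-example loss: squared dot product, returns a scalar.
    return jnp.dot(w, x) ** 2


w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 5.0], [0.0, 1.0]])  # batch of 3 examples

# Per-example gradients: map grad(loss) over axis 0 of xs, sharing w (in_axes=None).
per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, xs)  # shape (3, 2)


def batch_loss(w):
    # Sum of the per-example losses over the batch.
    return jnp.sum(jax.vmap(loss, in_axes=(None, 0))(w, xs))


# Total gradient: differentiate the summed loss, shape (2,).
total = jax.grad(batch_loss)(w)
```

Summing the per-example gradients over the batch axis recovers the total gradient, which mirrors the role the psum plays in the reduce_axes description.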
- Returns:
A function with the same arguments as fun that evaluates the gradient of fun. If argnums is an integer, the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True, a pair of (gradient, auxiliary_data) is returned.
- Return type:
Callable
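A minimal sketch of how argnums and has_aux shape the return value, assuming a hypothetical linear model loss(w, b, x) that returns a scalar loss together with its predictions as auxiliary data:

```python
import jax
import jax.numpy as jnp


def loss(w, b, x):
    # Hypothetical linear model: scalar mean-squared output, predictions as aux data.
    pred = x @ w + b
    return jnp.mean(pred ** 2), pred


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
w = jnp.array([0.5, -0.5])
b = 0.1

# argnums=(0, 1): the gradient is a tuple matching (w, b).
# has_aux=True: the call returns (gradient, auxiliary_data).
(dw, db), pred = jax.grad(loss, argnums=(0, 1), has_aux=True)(w, b, x)

# Default argnums=0: the gradient is a single array shaped like w.
dw_only, _ = jax.grad(loss, has_aux=True)(w, b, x)
```

Here dw has the same shape as w, db is a scalar like b, and pred is passed through untouched as auxiliary data.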
For example:

>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
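The following sketch exercises the dtype and shape rules stated for the parameters: holomorphic=True for a complex function, allow_int=True for an integer argument, and the requirement that fun return a scalar. The functions square and scale are hypothetical; the TypeError checks reflect the documented requirements, though the exact error messages may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

# holomorphic=True: complex input and output; grad returns the complex derivative.
square = lambda z: z ** 2  # hypothetical holomorphic function, f'(z) = 2z
z = jnp.asarray(1.0 + 1.0j, dtype=jnp.complex64)
dz = jax.grad(square, holomorphic=True)(z)

# allow_int=True: an integer argument in argnums receives a float0 gradient.
scale = lambda x, n: x * n  # hypothetical function with an integer argument n
dx, dn = jax.grad(scale, argnums=(0, 1), allow_int=True)(2.0, 3)

# Without allow_int, an integer argument in argnums is rejected.
try:
    jax.grad(scale, argnums=(0, 1))(2.0, 3)
    int_rejected = False
except TypeError:
    int_rejected = True

# An output of shape (1,) is not a scalar, so grad rejects it.
try:
    jax.grad(lambda x: jnp.reshape(x, (1,)))(1.0)
    vector_rejected = False
except TypeError:
    vector_rejected = True
```

The float0 gradient for n carries no numerical information; it only preserves the tuple structure implied by argnums.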