Convenience function for defining custom VJP rules (aka custom gradients).

While the canonical way to define custom VJP rules is via jax.custom_vjp, the custom_gradient convenience wrapper follows the API of TensorFlow's tf.custom_gradient. The difference is that custom_gradient can be used as a decorator on a single function that returns both the primal value (the output of the mathematical function to be differentiated) and the VJP (gradient) function. See https://www.tensorflow.org/api_docs/python/tf/custom_gradient.

If the mathematical function to be differentiated has Haskell-like signature a -> b, then the Python callable fun should have the signature a -> (b, CT b --o CT a) where we use CT x to denote a cotangent type for x and the --o arrow to denote a linear function. See the example below. That is, fun should return a pair where the first element represents the value of the mathematical function to be differentiated and the second element is a function to be called on the backward pass of reverse-mode automatic differentiation (i.e. the “custom gradient” function).

The function returned as the second element of the output of fun can close over intermediate values computed while evaluating the function to be differentiated. That is, use lexical closure to share work between the forward pass and the backward pass of reverse-mode automatic differentiation. However, it cannot perform Python control flow that depends on the values of the closed-over intermediates or of its cotangent arguments; if it does, an error is raised.


Parameters:

fun – a Python callable specifying both the mathematical function to be differentiated and its reverse-mode differentiation rule. It should return a pair consisting of an output value and a Python callable representing the custom gradient function.


Returns:

A Python callable that accepts the same arguments as fun and returns the output value given by the first element of fun's output pair.

For example:

>>> @jax.custom_gradient
... def f(x):
...   return x ** 2, lambda g: (g * x,)
>>> print(f(3.))
9.0
>>> print(jax.grad(f)(3.))
3.0

An example with a function on two arguments, so that the VJP function must return a tuple of length two:

>>> @jax.custom_gradient
... def f(x, y):
...   return x * y, lambda g: (g * y, g * x)
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))
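As a sketch of sharing work via lexical closure described above, the VJP function below reuses the intermediate jnp.exp(x) computed on the forward pass instead of recomputing it on the backward pass (the function name log1pexp is chosen for this illustration and is not part of the API):

```python
import jax
import jax.numpy as jnp

@jax.custom_gradient
def log1pexp(x):
  # Forward pass: compute the intermediate once.
  s = jnp.exp(x)
  # The VJP closes over `s`; d/dx log(1 + e^x) = e^x / (1 + e^x),
  # so the backward pass reuses `s` rather than recomputing exp(x).
  return jnp.log1p(s), lambda g: (g * s / (1. + s),)

print(log1pexp(0.))            # log(2) ≈ 0.6931472
print(jax.grad(log1pexp)(0.))  # 0.5
```

Note that, as with the examples above, the VJP function returns a tuple with one cotangent per positional argument of the decorated function.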