- class jax.custom_jvp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave like the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.grad()) is applied. In that case, a custom user-supplied JVP rule function is used instead of tracing into the underlying function's implementation and differentiating it automatically.
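A common reason to replace traced autodiff with a custom rule is numerical stability. The sketch below is modeled on the `log1pexp` example from the JAX custom-derivatives tutorial; the names `log1pexp` and `log1pexp_autodiff` are illustrative. It assumes default float32 precision, where `jnp.exp(100.0)` overflows to `inf`.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    # Direct formula: exp(x) overflows to inf for large x in float32.
    return jnp.log(1.0 + jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    primal_out = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x); this form stays finite
    # even when e^x overflows to inf (1 / inf == 0).
    tangent_out = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return primal_out, tangent_out


def log1pexp_autodiff(x):
    # Same formula, but differentiated by tracing into the implementation.
    return jnp.log(1.0 + jnp.exp(x))


print(jax.grad(log1pexp)(100.0))           # custom rule
print(jax.grad(log1pexp_autodiff)(100.0))  # traced autodiff
```

With the custom rule the gradient at `100.0` should be `1.0`, whereas the traced version multiplies `0` by `inf` inside its derivative and yields `nan`.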
There are two instance methods available for defining the custom JVP rule: defjvp(), for defining a single custom JVP rule for all of the function's inputs, and, for convenience, defjvps(), which wraps defjvp() and lets you provide separate definitions for the partial derivatives of the function with respect to each of its arguments.
    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def f(x, y):
        return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
        x, y = primals
        x_dot, y_dot = tangents
        primal_out = f(x, y)
        tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
        return primal_out, tangent_out
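To see the rule in action, the sketch below repeats the `f` / `f_jvp` definition from the example and calls it through both forward and reverse mode. The evaluation point `(0.5, 2.0)` is arbitrary. Reverse mode works without a separate VJP rule because the tangent output is linear in `x_dot` and `y_dot`, so JAX can transpose it.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y


@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out


x, y = 0.5, 2.0

# Forward mode: jax.jvp invokes f_jvp with the tangents supplied here.
primal_out, tangent_out = jax.jvp(f, (x, y), (1.0, 0.0))

# Reverse mode: JAX linearizes f_jvp and transposes the tangent computation.
dfdx, dfdy = jax.grad(f, argnums=(0, 1))(x, y)

print(primal_out, tangent_out)  # sin(0.5) * 2.0 and cos(0.5) * 2.0
print(dfdx, dfdy)               # cos(0.5) * 2.0 and sin(0.5)
```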
For a more detailed introduction, see the tutorial.
- __init__(fun, nondiff_argnums=())
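- param fun: the underlying function whose JVP rule is being customized.
- param nondiff_argnums: positions of arguments that are not differentiated. These are passed to the JVP rule as leading arguments rather than as part of primals and tangents. Defaults to ().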
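As an illustration of `nondiff_argnums`, the sketch below (adapted from the JAX custom-derivatives tutorial; `app` and `cube` are illustrative names) marks a Python callable as non-differentiable. The rule receives that argument first, followed by `primals` and `tangents`, which contain only the remaining argument. The tangent rule `2 * x_dot` is deliberately not the true derivative, so the output makes it obvious that the custom rule is used.

```python
from functools import partial

import jax


# Argument 0 (a Python callable) is declared non-differentiable.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def app(g, x):
    return g(x)


@app.defjvp
def app_jvp(g, primals, tangents):
    # Non-differentiable arguments come first; primals/tangents hold only x.
    (x,) = primals
    (x_dot,) = tangents
    # Intentionally simple rule, not the derivative of g.
    return g(x), 2.0 * x_dot


def cube(x):
    return x ** 3


print(app(cube, 3.0))               # 27.0
print(jax.grad(app, 1)(cube, 3.0))  # 2.0 from the custom rule, not 3 * x**2 = 27.0
```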
- defjvp()
Define a custom JVP rule for the function represented by this instance.
- defjvps()
Convenience wrapper for defining JVPs for each argument separately.
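A sketch of the per-argument style, rewriting the `f` example with `defjvps`: it takes one function per positional argument, in order. Each function receives that argument's tangent, then the primal output, then all primal inputs, and the resulting contributions are summed into the output tangent. The evaluation point is arbitrary.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y


# One rule per argument: (tangent, primal_out, *primals) -> partial tangent.
f.defjvps(
    lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,  # w.r.t. x
    lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot,      # w.r.t. y
)

dfdx, dfdy = jax.grad(f, argnums=(0, 1))(0.5, 2.0)
print(dfdx, dfdy)  # cos(0.5) * 2.0 and sin(0.5)
```

This is equivalent to the single `defjvp` rule in the class example; the per-argument form can be shorter when the partial derivatives are independent of one another.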