jax.custom_jvp
- class jax.custom_jvp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp() or jax.grad()) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation.

There are two instance methods available for defining the custom JVP rule: defjvp() for defining a single custom JVP rule for all the function’s inputs, and, for convenience, defjvps(), which wraps defjvp() and allows you to provide separate definitions for the partial derivatives of the function w.r.t. each of its arguments.

For example:

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def f(x, y):
        return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
        x, y = primals
        x_dot, y_dot = tangents
        primal_out = f(x, y)
        tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
        return primal_out, tangent_out

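To see the rule in action, the following minimal sketch repeats the definition of f and f_jvp so it runs on its own, then applies jax.jvp() and jax.grad(). Both transformations go through f_jvp rather than differentiating jnp.sin(x) * y directly. The evaluation points (0, 2) and (1, 3) are arbitrary choices for illustration; at (x, y) = (0, 2) with tangent (1, 0) the rule gives a primal output of 0 and a tangent of cos(0) * 1 * 2 = 2.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out

# Forward mode: push the tangent (x_dot, y_dot) = (1, 0) through f at (x, y) = (0, 2).
out, out_dot = jax.jvp(f, (0.0, 2.0), (1.0, 0.0))
print(out, out_dot)

# Reverse mode: jax.grad linearizes f via f_jvp and then transposes the result.
grad_x = jax.grad(f)(0.0, 2.0)             # d/dx at (0, 2): cos(0) * 2
grad_y = jax.grad(f, argnums=1)(1.0, 3.0)  # d/dy at (1, 3): sin(1)
print(grad_x, grad_y)
```

Because reverse mode is derived from the same JVP rule, a single defjvp() definition is enough to support both jax.jvp() and jax.grad(), provided the returned tangent is linear in the input tangents.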
For a more detailed introduction, see the tutorial.
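A common reason to replace automatic differentiation of the implementation is numerical stability. The sketch below uses the standard log(1 + e^x) illustration and assumes JAX's default float32 precision; the names log1pexp_naive and log1pexp are illustrative only. At x = 100, e^x overflows to inf, so differentiating through the naive implementation can produce nan, while the hand-written derivative 1 - 1/(1 + e^x) stays finite.

```python
import jax
import jax.numpy as jnp

def log1pexp_naive(x):
    # Plain implementation: autodiff traces through exp and log.
    return jnp.log(1.0 + jnp.exp(x))

@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x); this stays finite (1 - 0 = 1)
    # even when e^x overflows to inf in float32.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

print(jax.grad(log1pexp_naive)(100.0))  # autodiff of the implementation
print(jax.grad(log1pexp)(100.0))        # custom rule
```

Only the derivative is fixed here; the primal value log1pexp(100.0) still overflows with this implementation, which is a separate concern from the JVP rule.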
- Parameters:
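fun (Callable[..., ReturnValue]): the function whose differentiation rule is being customized.
nondiff_argnums (Sequence[int]): indices of positional arguments that should not be differentiated.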
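When nondiff_argnums is set, the arguments at those positions have no tangents; they are passed to the JVP rule as leading positional arguments, ahead of primals and tangents. The sketch below assumes a hypothetical helper apply_fn whose first argument is a Python callable. Its rule simply forwards to jax.jvp() on that callable, which is meant only to show the calling convention, not a useful custom derivative.

```python
from functools import partial

import jax

# Position 0 (the callable `g`) is marked non-differentiable.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def apply_fn(g, x):
    return g(x)

@apply_fn.defjvp
def apply_fn_jvp(g, primals, tangents):
    # The non-differentiable argument `g` arrives first, before primals/tangents.
    (x,), (x_dot,) = primals, tangents
    # Illustrative rule only: forward to jax.jvp on g itself.
    return jax.jvp(g, (x,), (x_dot,))

def cube(z):
    return z ** 3

value = apply_fn(cube, 3.0)
slope = jax.grad(lambda x: apply_fn(cube, x))(3.0)  # 3 * 3.0**2
print(value, slope)
```

Marking the callable as non-differentiable keeps JAX from trying to treat it as an array input to be traced and differentiated.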
- __init__(fun, nondiff_argnums=())
- Parameters:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
Methods
__init__(fun[, nondiff_argnums])
defjvp(jvp[, symbolic_zeros]): Define a custom JVP rule for the function represented by this instance.
defjvps(*jvps): Convenience wrapper for defining JVPs for each argument separately.
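For comparison with defjvp(), here is a minimal sketch of the same f(x, y) = sin(x) * y rule written with defjvps(). Each function passed to defjvps() corresponds to one positional argument and receives that argument's tangent, the primal output, and then the primal inputs; the per-argument contributions are summed into the output tangent.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

# One rule per positional argument. Each receives that argument's tangent,
# the primal output, and then all primal inputs.
f.defjvps(
    lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
    lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot,
)

# Perturb both inputs at once; the two contributions are summed.
out, out_dot = jax.jvp(f, (0.0, 2.0), (1.0, 1.0))
grad_x = jax.grad(f)(0.0, 2.0)
grad_y = jax.grad(f, argnums=1)(1.0, 3.0)
print(out, out_dot, grad_x, grad_y)
```

This form avoids writing out the sum of partial derivatives by hand, at the cost of one lambda per argument.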
Attributes
jvp
symbolic_zeros
fun
nondiff_argnums