jax.custom_jvp

class jax.custom_jvp(fun, nondiff_argnums=())

Set up a JAX-transformable function for a custom JVP rule definition.

This class is meant to be used as a function decorator. Instances are callables that behave like the underlying function they decorate, except under a differentiation transformation such as jax.jvp() or jax.grad(). In that case, the custom user-supplied JVP rule is used instead of tracing into the underlying function’s implementation and differentiating it automatically.
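To illustrate why bypassing automatic differentiation can matter, the following sketch (function names are illustrative) defines log(1 + exp(x)) twice. It assumes JAX's default float32 precision, in which exp(100) overflows to inf, so the automatically derived gradient evaluates to an inf/inf form and yields nan. The second version attaches a custom rule through the defjvp() method described below, using the algebraically equivalent derivative 1 - 1/(1 + exp(x)), which stays finite.

```python
import jax
import jax.numpy as jnp


def log1pexp_plain(x):
  # Naive implementation: exp(x) overflows to inf in float32 for large x.
  return jnp.log(1.0 + jnp.exp(x))


@jax.custom_jvp
def log1pexp(x):
  # Same primal computation as above.
  return jnp.log(1.0 + jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  (x,) = primals
  (x_dot,) = tangents
  primal_out = log1pexp(x)
  # d/dx log(1 + e^x) rewritten as 1 - 1 / (1 + e^x), which tends to 1, not inf/inf.
  tangent_out = x_dot * (1.0 - 1.0 / (1.0 + jnp.exp(x)))
  return primal_out, tangent_out


# Autodiff traces into the naive implementation and produces nan at x = 100.
print(jax.grad(log1pexp_plain)(100.0))
# The custom rule is used instead, giving the correct slope of 1.
print(jax.grad(log1pexp)(100.0))
```

Only the derivative is repaired here; the primal value at x = 100 is still inf in float32.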

Two instance methods define the custom JVP rule: defjvp(), which defines a single rule covering all of the function’s inputs, and the convenience wrapper defjvps(), which is built on defjvp() and lets you supply a separate definition for the partial derivative of the function with respect to each of its arguments.

For example:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  # Product rule: d(sin(x) * y) = cos(x) * x_dot * y + sin(x) * y_dot
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out
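As a quick check of the example, the sketch below repeats the definitions so it runs on its own, then applies forward-mode jax.jvp() and reverse-mode jax.grad(). Both go through f_jvp, and the results should match the hand-derived partials ∂f/∂x = cos(x)·y and ∂f/∂y = sin(x). The input values are arbitrary.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y


@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out


x, y = jnp.float32(2.0), jnp.float32(3.0)

# Forward mode: pushing the tangent (1, 0) through f gives the derivative w.r.t. x.
y_out, t_out = jax.jvp(f, (x, y), (jnp.float32(1.0), jnp.float32(0.0)))
print(y_out, t_out)  # sin(2) * 3 and cos(2) * 3

# Reverse mode: JAX linearizes and transposes the custom rule to get both partials.
dx, dy = jax.grad(f, argnums=(0, 1))(x, y)
print(dx, dy)  # cos(2) * 3 and sin(2)
```

Reverse mode works without a separate VJP definition because the tangent output of f_jvp is linear in x_dot and y_dot, so JAX can transpose it.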

For a more detailed introduction, see the tutorial.

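Parameters:
fun: the function whose derivative is to be customized.
nondiff_argnums: an optional sequence of integer indices of positional arguments of fun that should not be differentiated, such as Python callables or other non-array values. The values of these arguments are passed to the JVP rule as its leading arguments, before primals and tangents.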
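The following is a minimal sketch of nondiff_argnums. The names app and app_jvp are hypothetical. The first argument is a Python callable, which cannot be differentiated, so it is marked with nondiff_argnums=(0,). The rule then receives it first, and primals and tangents contain only the remaining argument. The rule deliberately returns 2 * x_dot regardless of the callable, so the gradient shows that the custom rule, not automatic differentiation, is in effect.

```python
from functools import partial

import jax


# Argument 0 (a Python callable) is marked as non-differentiable.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def app(g, x):
  return g(x)


@app.defjvp
def app_jvp(g, primals, tangents):
  # Non-differentiable arguments come first; primals/tangents hold only x.
  (x,) = primals
  (x_dot,) = tangents
  # Deliberately not the true derivative, so the custom rule's effect is visible.
  return g(x), 2.0 * x_dot


cube = lambda x: x ** 3
print(app(cube, 3.0))                        # primal is unchanged: 3 ** 3
print(jax.grad(app, argnums=1)(cube, 3.0))   # 2.0 from the rule, not 3 * 3 ** 2
```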

Methods

__init__(fun[, nondiff_argnums])

defjvp(jvp[, symbolic_zeros])

Define a custom JVP rule for the function represented by this instance.

defjvps(*jvps)

Convenience wrapper for defining JVPs for each argument separately.
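As a sketch of defjvps(), the function f(x, y) = sin(x) * y from the class example can be given an equivalent rule with one function per positional argument. Each function receives that argument's tangent, the primal output, and then all primal inputs, and returns that argument's contribution to the output tangent. JAX sums the contributions. The example is self-contained and the input values are arbitrary.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y


# One function per positional argument, each called as (arg_tangent, primal_out, *primals).
f.defjvps(
    lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,  # contribution from x
    lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot,      # contribution from y
)

x, y = jnp.float32(2.0), jnp.float32(3.0)
dx, dy = jax.grad(f, argnums=(0, 1))(x, y)
print(dx, dy)  # cos(2) * 3 and sin(2), matching the defjvp version
```

This form is convenient when each partial derivative is easy to state on its own. defjvp() is more flexible when the contributions share intermediate computations.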

Attributes

jvp

symbolic_zeros

fun

nondiff_argnums