jax.custom_jvp.defjvp
- custom_jvp.defjvp(jvp, symbolic_zeros=False)
Define a custom JVP rule for the function represented by this instance.
- Parameters:
  - jvp (Callable[..., tuple[ReturnValue, ReturnValue]]) – a Python callable representing the custom JVP rule. When there are no nondiff_argnums, the jvp function should accept two arguments: the first is a tuple of primal inputs and the second is a tuple of tangent inputs. Both tuples have length equal to the number of parameters of the custom_jvp function. The jvp function should return a pair whose first element is the primal output and whose second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof.
  - symbolic_zeros (bool) – whether the rule should be passed objects representing static symbolic zeros in its tangent argument for inputs that are not perturbed; otherwise, only standard JAX types (e.g. array-likes) are passed. Setting this option to True allows a JVP rule to detect whether certain inputs are not involved in differentiation, at the cost of needing special handling for these objects (which, e.g., can’t be passed into jax.numpy functions). Default False.
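The symbolic_zeros option is easiest to see in a small rule. The sketch below is a hedged illustration, not the library's reference usage: it assumes that jax.custom_derivatives.SymbolicZero is the type of the symbolic-zero objects handed to the rule, and it reuses the sin(x) * y function from the Examples section. Because symbolic_zeros has to be passed as a keyword, functools.partial is used to keep the decorator form.

```python
import functools

import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero  # assumed type of symbolic-zero tangents


@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y


# functools.partial lets defjvp act as a decorator while also passing symbolic_zeros.
@functools.partial(f.defjvp, symbolic_zeros=True)
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    tangent_out = jnp.zeros_like(primal_out)
    # A SymbolicZero tangent means that input is not being differentiated,
    # so its term is skipped instead of being multiplied by an array of zeros.
    if not isinstance(x_dot, SymbolicZero):
        tangent_out = tangent_out + jnp.cos(x) * x_dot * y
    if not isinstance(y_dot, SymbolicZero):
        tangent_out = tangent_out + jnp.sin(x) * y_dot
    return primal_out, tangent_out
```

Under jax.grad(f), only x is perturbed, so y_dot arrives as a symbolic zero and the sin(x) * y_dot term is never computed.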
- Returns:
  Returns jvp so that defjvp can be used as a decorator.
- Return type:
  Callable[..., tuple[ReturnValue, ReturnValue]]
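Because defjvp returns the rule it was given, it can also be called as an ordinary method instead of a decorator. A minimal sketch, using a hypothetical cube function:

```python
import jax


@jax.custom_jvp
def cube(x):
    return x ** 3


def cube_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # d/dx x**3 = 3 x**2, applied linearly to the tangent.
    return cube(x), 3.0 * x ** 2 * x_dot


# Called directly, defjvp registers the rule and hands the same callable back.
returned = cube.defjvp(cube_jvp)
```

Since the returned object is cube_jvp itself, the decorator form leaves the rule's name bound to the rule rather than to None.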
Examples
```python
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> @f.defjvp
... def f_jvp(primals, tangents):
...   x, y = primals
...   x_dot, y_dot = tangents
...   primal_out = f(x, y)
...   tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
...   return primal_out, tangent_out

>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
```
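The same rule signature applies when inputs are nested containers, as the parameter description allows, and a rule can be exercised directly with jax.jvp, which takes primals and tangents in the same tuple layout the rule receives. A minimal sketch, assuming a hypothetical weighted_sum function whose first argument is a dict of arrays:

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def weighted_sum(params, x):
    # params is a dict pytree with keys "w" (vector) and "b" (scalar).
    return jnp.dot(params["w"], x) + params["b"]


@weighted_sum.defjvp
def weighted_sum_jvp(primals, tangents):
    params, x = primals
    params_dot, x_dot = tangents  # params_dot has the same dict structure as params
    primal_out = weighted_sum(params, x)
    # Product rule for w . x, plus the tangent of the bias term.
    tangent_out = (jnp.dot(params_dot["w"], x)
                   + jnp.dot(params["w"], x_dot)
                   + params_dot["b"])
    return primal_out, tangent_out


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.float32(0.5)}
x = jnp.array([3.0, 4.0])
params_dot = {"w": jnp.array([1.0, 0.0]), "b": jnp.float32(0.0)}
x_dot = jnp.zeros_like(x)

y, y_dot = jax.jvp(weighted_sum, (params, x), (params_dot, x_dot))
```

With these inputs y is 1*3 + 2*4 + 0.5 = 11.5 and y_dot is 1*3 = 3.0; jax.grad(weighted_sum)(params, x) uses the same rule and returns a dict shaped like params.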