{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "LqiaKasFjH82" }, "source": [ "# Custom derivative rules for JAX-transformable Python functions\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n", "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs."
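] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick refresher on those meanings: for $f : \\mathbb{R}^n \\to \\mathbb{R}^m$ with Jacobian $\\partial f(x) \\in \\mathbb{R}^{m \\times n}$, `jax.jvp` pushes a tangent vector $v \\in \\mathbb{R}^n$ forward to the Jacobian-vector product $\\partial f(x) \\, v$. In the other direction, `jax.vjp` pulls a cotangent vector $w \\in \\mathbb{R}^m$ back to the vector-Jacobian product $w^\\top \\partial f(x)$. For scalar-valued $f$, `jax.grad` amounts to a VJP with $w = 1$, and by default it keeps only the first argument's cotangent. The next cell is a minimal illustrative sketch of this on $f(x, y) = \\sin(x) \\, y$, evaluated at the arbitrarily chosen point $(2, 3)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "def f(x, y):\n", "    return jnp.sin(x) * y\n", "\n", "# JVP: push the tangent (x_dot, y_dot) = (1., 0.) forward through f at (2., 3.),\n", "# which computes cos(x) * y * x_dot + sin(x) * y_dot.\n", "y_out, y_dot = jax.jvp(f, (2., 3.), (1., 0.))\n", "\n", "# VJP: pull the cotangent 1. back through f at (2., 3.), which computes\n", "# one cotangent per input: (cos(x) * y * 1., sin(x) * 1.).\n", "y_out, f_vjp = jax.vjp(f, 2., 3.)\n", "x_bar, y_bar = f_vjp(jnp.ones_like(y_out))\n", "\n", "print(y_dot)                # the JVP seeded on x alone gives df/dx\n", "print(x_bar, y_bar)         # one VJP gives df/dx and df/dy together\n", "print(jax.grad(f)(2., 3.))  # grad keeps the first argument's cotangent, matching x_bar"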
] }, { "cell_type": "markdown", "metadata": { "id": "9Fg3NFNY-2RY" }, "source": [ "## TL;DR" ] }, { "cell_type": "markdown", "metadata": { "id": "ZgMNRtXyWIW8" }, "source": [ "### Custom JVPs with `jax.custom_jvp`" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "zXic8tr--1PK" }, "outputs": [], "source": [ "import jax.numpy as jnp\n", "from jax import custom_jvp\n", "\n", "@custom_jvp\n", "def f(x, y):\n", "    return jnp.sin(x) * y\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", "    x, y = primals\n", "    x_dot, y_dot = tangents\n", "    primal_out = f(x, y)\n", "    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot\n", "    return primal_out, tangent_out" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "RrNf588X_kJF", "outputId": "b962bafb-e8a3-4b0d-ddf4-202e088231c3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.7278922\n", "2.7278922\n", "-1.2484405\n", "-1.2484405\n" ] } ], "source": [ "from jax import jvp, grad\n", "\n", "print(f(2., 3.))\n", "y, y_dot = jvp(f, (2., 3.), (1., 0.))\n", "print(y)\n", "print(y_dot)\n", "print(grad(f)(2., 3.))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "1kHd3cKOWQgB" }, "outputs": [], "source": [ "# Equivalent alternative using the defjvps convenience wrapper\n", "\n", "@custom_jvp\n", "def f(x, y):\n", "    return jnp.sin(x) * y\n", "\n", "f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,\n", "          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "Zn81cHeYWVOw", "outputId": "bf29b66c-897b-485e-c0a0-ee0fbd729a95" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.7278922\n", "2.7278922\n", "-1.2484405\n", "-1.2484405\n" ] } ], "source": [ "print(f(2., 3.))\n", "y, y_dot = jvp(f, (2., 3.), (1., 0.))\n", "print(y)\n", "print(y_dot)\n", "print(grad(f)(2., 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "N2DOGCREWXFj"
}, "source": [ "### Custom VJPs with `jax.custom_vjp`" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "35ScHqhrBwPh" }, "outputs": [], "source": [ "from jax import custom_vjp\n", "\n", "@custom_vjp\n", "def f(x, y):\n", "    return jnp.sin(x) * y\n", "\n", "def f_fwd(x, y):\n", "    # Returns primal output and residuals to be used in the backward pass by f_bwd.\n", "    return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n", "\n", "def f_bwd(res, g):\n", "    cos_x, sin_x, y = res  # Gets residuals computed in f_fwd\n", "    return (cos_x * g * y, sin_x * g)\n", "\n", "f.defvjp(f_fwd, f_bwd)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "HpSozxKUCXgp", "outputId": "57277102-7bdb-41f0-c805-a27fcf9fb1ae" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-1.2484405\n" ] } ], "source": [ "print(grad(f)(2., 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "p5ypWA7XlZpu" }, "source": [ "## Example problems\n", "\n", "To get an idea of what problems `jax.custom_jvp` and `jax.custom_vjp` are meant to solve, let's go over a few examples. A more thorough introduction to the `jax.custom_jvp` and `jax.custom_vjp` APIs is in the next section." ] }, { "cell_type": "markdown", "metadata": { "id": "AR02eyd1GQhC" }, "source": [ "### Numerical stability\n", "\n", "One application of `jax.custom_jvp` is to improve the numerical stability of differentiation." ] }, { "cell_type": "markdown", "metadata": { "id": "GksPXslaGPaW" }, "source": [ "Say we want to write a function called `log1pexp`, which computes $x \\mapsto \\log ( 1 + e^x )$.
We can write that using `jax.numpy`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "6lWbTvs40ET-", "outputId": "8caff99e-add1-4c70-ace3-212c0c5c6f4e" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(3.0485873, dtype=float32)" ] }, "execution_count": 8, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "import jax.numpy as jnp\n", "\n", "def log1pexp(x):\n", "    return jnp.log(1. + jnp.exp(x))\n", "\n", "log1pexp(3.)" ] }, { "cell_type": "markdown", "metadata": { "id": "PL36r_cD0oE8" }, "source": [ "Since it's written in terms of `jax.numpy`, it's JAX-transformable:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "XgtGKFld02UD", "outputId": "809d399d-8eca-401e-b969-810e46648571" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.0485873\n", "0.95257413\n", "[0.5 0.7310586 0.88079715]\n" ] } ], "source": [ "from jax import jit, grad, vmap\n", "\n", "print(jit(log1pexp)(3.))\n", "print(jit(grad(log1pexp))(3.))\n", "print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))" ] }, { "cell_type": "markdown", "metadata": { "id": "o56Nr3V61PKS" }, "source": [ "But there's a numerical stability problem lurking here:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "sVM6iwIO22sB", "outputId": "9c935ee8-f174-475a-ca01-fc80949199e5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nan\n" ] } ], "source": [ "print(grad(log1pexp)(100.))" ] }, { "cell_type": "markdown", "metadata": { "id": "Zu9sR2I73wuO" }, "source": [ "That doesn't seem right!
After all, the derivative of $x \\mapsto \\log (1 + e^x)$ is $x \\mapsto \\frac{e^x}{1 + e^x}$, and so for large values of $x$ we'd expect the value to be about 1.\n", "\n", "We can get a bit more insight into what's going on by looking at the jaxpr for the gradient computation:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "dO6uZlYR4TVp", "outputId": "61e06b1e-14cd-4030-f330-a949be185df8" }, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a.\n", " let b = exp a\n", " c = add b 1.0\n", " _ = log c\n", " d = div 1.0 c\n", " e = mul d b\n", " in (e,) }" ] }, "execution_count": 11, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "from jax import make_jaxpr\n", "\n", "make_jaxpr(grad(log1pexp))(100.)" ] }, { "cell_type": "markdown", "metadata": { "id": "52HR5EW26PEt" }, "source": [ "Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating-point math will round to 0 and $\\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)` for large `x`, which effectively turns into `0. * jnp.inf`, and that evaluates to `nan`.\n", "\n", "Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that evaluates the mathematically equivalent expression $1 - \\frac{1}{1 + e^x}$, with no cancellation in sight.\n", "\n", "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result.
Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n", "\n", "This is one application of custom derivative rules for Python functions that are already JAX-transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...).\n", "\n", "Here's a solution using `jax.custom_jvp`:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "XQt6MAuTJewG" }, "outputs": [], "source": [ "from jax import custom_jvp\n", "\n", "@custom_jvp\n", "def log1pexp(x):\n", "    return jnp.log(1. + jnp.exp(x))\n", "\n", "@log1pexp.defjvp\n", "def log1pexp_jvp(primals, tangents):\n", "    x, = primals\n", "    x_dot, = tangents\n", "    ans = log1pexp(x)\n", "    ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot\n", "    return ans, ans_dot" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "rhiMHulfKBIF", "outputId": "883bc4d2-3a1b-48d3-b205-c500f77d229c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "print(grad(log1pexp)(100.))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "9cLDuAo6KGUu", "outputId": "59984494-6124-4540-84fd-608ad4fc6bc6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.0485873\n", "0.95257413\n", "[0.5 0.7310586 0.8807971]\n" ] } ], "source": [ "print(jit(log1pexp)(3.))\n", "print(jit(grad(log1pexp))(3.))\n", "print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))" ] }, { "cell_type": "markdown", "metadata": { "id": "9sVUGbGkUOqO" }, "source": [ "Here's the same rule expressed with the `defjvps` convenience wrapper:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "xfQTp8F7USEM" }, "outputs": [], "source": [ "@custom_jvp\n", "def log1pexp(x):\n", "    return jnp.log(1.
+ jnp.exp(x))\n", "\n", "log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "dtdh-PLaUsvw", "outputId": "aa36aec6-15af-4397-fc55-8b9fb7e607d8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n", "3.0485873\n", "0.95257413\n", "[0.5 0.7310586 0.8807971]\n" ] } ], "source": [ "print(grad(log1pexp)(100.))\n", "print(jit(log1pexp)(3.))\n", "print(jit(grad(log1pexp))(3.))\n", "print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))" ] }, { "cell_type": "markdown", "metadata": { "id": "V9tHAfrSF1N-" }, "source": [ "### Enforcing a differentiation convention\n", "\n", "A related application is to enforce a differentiation convention, perhaps at a boundary." ] }, { "cell_type": "markdown", "metadata": { "id": "l_6tdb-QGK-H" }, "source": [ "Consider the function $f : \\mathbb{R}_+ \\to \\mathbb{R}_+$ with $f(x) = \\frac{x}{1 + \\sqrt{x}}$, where we take $\\mathbb{R}_+ = [0, \\infty)$. We might implement $f$ as a program like this:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "AfF5P7x_GaSe" }, "outputs": [], "source": [ "def f(x):\n", "    return x / (1 + jnp.sqrt(x))" ] }, { "cell_type": "markdown", "metadata": { "id": "BVcEkF3ZGgv1" }, "source": [ "As a mathematical function on $\\mathbb{R}$ (the full real line), $f$ is not differentiable at zero (because the limit defining the derivative doesn't exist from the left).
Correspondingly, autodiff produces a `nan` value:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "piI0u5MiHhQh", "outputId": "c045308f-2f3b-4c22-ebb2-b9ee582b4d25" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nan\n" ] } ], "source": [ "print(grad(f)(0.))" ] }, { "cell_type": "markdown", "metadata": { "id": "IP0H2b7ZHkzD" }, "source": [ "But mathematically, if we think of $f$ as a function on $\\mathbb{R}_+$, then it is differentiable at 0 [Rudin's Principles of Mathematical Analysis Definition 5.1, or Tao's Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function `grad(f)` to return at `0.0`, namely `1.0`. By default, JAX's machinery for differentiation assumes all functions are defined over $\\mathbb{R}$ and thus doesn't produce `1.0` here.\n", "\n", "We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function $x \\mapsto \\frac{\\sqrt{x} + 2}{2(\\sqrt{x} + 1)^2}$ on $\\mathbb{R}_+$, which evaluates to $\\frac{2}{2} = 1$ at $x = 0$:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "ksHmCkcSKQJr" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x):\n", "    return x / (1 + jnp.sqrt(x))\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", "    x, = primals\n", "    x_dot, = tangents\n", "    ans = f(x)\n", "    ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot\n", "    return ans, ans_dot" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "Gsh9ZvMTKi1O", "outputId": "a3076175-6542-4210-ce4a-d0d82e0051c6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "print(grad(f)(0.))" ] }, { "cell_type": "markdown", "metadata": { "id": "Usbp_gxaVVea" }, "source": [ "Here's the `defjvps` convenience wrapper version:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id":
"qXnrxIfaVYCs" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x):\n", "    return x / (1 + jnp.sqrt(x))\n", "\n", "f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "uUU5qRmEViK1", "outputId": "ea7dc2c4-a100-48f4-a74a-859070daf994" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "print(grad(f)(0.))" ] }, { "cell_type": "markdown", "metadata": { "id": "7J2A85wbSAmF" }, "source": [ "### Gradient clipping\n", "\n", "While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.\n", "\n", "For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "8jfjSanIW_tJ" }, "outputs": [], "source": [ "from functools import partial\n", "from jax import custom_vjp\n", "\n", "@custom_vjp\n", "def clip_gradient(lo, hi, x):\n", "    return x  # identity function\n", "\n", "def clip_gradient_fwd(lo, hi, x):\n", "    return x, (lo, hi)  # save bounds as residuals\n", "\n", "def clip_gradient_bwd(res, g):\n", "    lo, hi = res\n", "    return (None, None, jnp.clip(g, lo, hi))  # use None to indicate zero cotangents for lo and hi\n", "\n", "clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "4OLU_vf8Xw2J", "outputId": "5a51ff2c-79c2-41ba-eead-53679b4eddbc" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 24, "metadata": { "tags": [] }, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "