{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "LqiaKasFjH82" }, "source": [ "# Custom derivative rules for JAX-transformable Python functions\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n", "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." 
] }, { "cell_type": "markdown", "metadata": { "id": "9Fg3NFNY-2RY" }, "source": [ "## TL;DR" ] }, { "cell_type": "markdown", "metadata": { "id": "ZgMNRtXyWIW8" }, "source": [ "### Custom JVPs with `jax.custom_jvp`" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "zXic8tr--1PK" }, "outputs": [], "source": [ "import jax.numpy as jnp\n", "from jax import custom_jvp\n", "\n", "@custom_jvp\n", "def f(x, y):\n", " return jnp.sin(x) * y\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", " x, y = primals\n", " x_dot, y_dot = tangents\n", " primal_out = f(x, y)\n", " tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot\n", " return primal_out, tangent_out" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "RrNf588X_kJF", "outputId": "b962bafb-e8a3-4b0d-ddf4-202e088231c3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.7278922\n", "2.7278922\n", "-1.2484405\n", "-1.2484405\n" ] } ], "source": [ "from jax import jvp, grad\n", "\n", "print(f(2., 3.))\n", "y, y_dot = jvp(f, (2., 3.), (1., 0.))\n", "print(y)\n", "print(y_dot)\n", "print(grad(f)(2., 3.))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "1kHd3cKOWQgB" }, "outputs": [], "source": [ "# Equivalent alternative using the defjvps convenience wrapper\n", "\n", "@custom_jvp\n", "def f(x, y):\n", " return jnp.sin(x) * y\n", "\n", "f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,\n", " lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "Zn81cHeYWVOw", "outputId": "bf29b66c-897b-485e-c0a0-ee0fbd729a95" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.7278922\n", "2.7278922\n", "-1.2484405\n", "-1.2484405\n" ] } ], "source": [ "print(f(2., 3.))\n", "y, y_dot = jvp(f, (2., 3.), (1., 0.))\n", "print(y)\n", "print(y_dot)\n", "print(grad(f)(2., 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "N2DOGCREWXFj" 
}, "source": [ "### Custom VJPs with `jax.custom_vjp`" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "35ScHqhrBwPh" }, "outputs": [], "source": [ "from jax import custom_vjp\n", "\n", "@custom_vjp\n", "def f(x, y):\n", " return jnp.sin(x) * y\n", "\n", "def f_fwd(x, y):\n", "# Returns primal output and residuals to be used in backward pass by f_bwd.\n", " return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n", "\n", "def f_bwd(res, g):\n", " cos_x, sin_x, y = res # Gets residuals computed in f_fwd\n", " return (cos_x * g * y, sin_x * g)\n", "\n", "f.defvjp(f_fwd, f_bwd)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "HpSozxKUCXgp", "outputId": "57277102-7bdb-41f0-c805-a27fcf9fb1ae" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-1.2484405\n" ] } ], "source": [ "print(grad(f)(2., 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "p5ypWA7XlZpu" }, "source": [ "## Example problems\n", "\n", "To get an idea of what problems `jax.custom_jvp` and `jax.custom_vjp` are meant to solve, let's go over a few examples. A more thorough introduction to the `jax.custom_jvp` and `jax.custom_vjp` APIs is in the next section." ] }, { "cell_type": "markdown", "metadata": { "id": "AR02eyd1GQhC" }, "source": [ "### Numerical stability\n", "\n", "One application of `jax.custom_jvp` is to improve the numerical stability of differentiation." ] }, { "cell_type": "markdown", "metadata": { "id": "GksPXslaGPaW" }, "source": [ "Say we want to write a function called `log1pexp`, which computes $x \\mapsto \\log ( 1 + e^x )$. 
We can write that using `jax.numpy`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "6lWbTvs40ET-", "outputId": "8caff99e-add1-4c70-ace3-212c0c5c6f4e" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(3.0485873, dtype=float32)" ] }, "execution_count": 8, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "import jax.numpy as jnp\n", "\n", "def log1pexp(x):\n", " return jnp.log(1. + jnp.exp(x))\n", "\n", "log1pexp(3.)" ] }, { "cell_type": "markdown", "metadata": { "id": "PL36r_cD0oE8" }, "source": [ "Since it's written in terms of `jax.numpy`, it's JAX-transformable:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "XgtGKFld02UD", "outputId": "809d399d-8eca-401e-b969-810e46648571" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.0485873\n", "0.95257413\n", "[0.5 0.7310586 0.88079715]\n" ] } ], "source": [ "from jax import jit, grad, vmap\n", "\n", "print(jit(log1pexp)(3.))\n", "print(jit(grad(log1pexp))(3.))\n", "print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))" ] }, { "cell_type": "markdown", "metadata": { "id": "o56Nr3V61PKS" }, "source": [ "But there's a numerical stability problem lurking here:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "sVM6iwIO22sB", "outputId": "9c935ee8-f174-475a-ca01-fc80949199e5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nan\n" ] } ], "source": [ "print(grad(log1pexp)(100.))" ] }, { "cell_type": "markdown", "metadata": { "id": "Zu9sR2I73wuO" }, "source": [ "That doesn't seem right! 
After all, the derivative of $x \\mapsto \\log (1 + e^x)$ is $x \\mapsto \\frac{e^x}{1 + e^x}$, and so for large values of $x$ we'd expect the value to be about 1.\n", "\n", "We can get a bit more insight into what's going on by looking at the jaxpr for the gradient computation:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "dO6uZlYR4TVp", "outputId": "61e06b1e-14cd-4030-f330-a949be185df8" }, "outputs": [ { "data": { "text/plain": [ "{ lambda ; a.\n", " let b = exp a\n", " c = add b 1.0\n", " _ = log c\n", " d = div 1.0 c\n", " e = mul d b\n", " in (e,) }" ] }, "execution_count": 11, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "from jax import make_jaxpr\n", "\n", "make_jaxpr(grad(log1pexp))(100.)" ] }, { "cell_type": "markdown", "metadata": { "id": "52HR5EW26PEt" }, "source": [ "Stepping through how the jaxpr would be evaluated, we can see that the last line would involve multiplying values that floating point math will round to 0 and $\\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)` for large `x`, which turns into `0. * jnp.inf`.\n", "\n", "Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equivalent mathematical expression $1 - \\frac{1}{1 + e^x}$, with no cancellation in sight.\n", "\n", "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. 
Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n", "\n", "This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...).\n", "\n", "Here's a solution using `jax.custom_jvp`:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "XQt6MAuTJewG" }, "outputs": [], "source": [ "from jax import custom_jvp\n", "\n", "@custom_jvp\n", "def log1pexp(x):\n", " return jnp.log(1. + jnp.exp(x))\n", "\n", "@log1pexp.defjvp\n", "def log1pexp_jvp(primals, tangents):\n", " x, = primals\n", " x_dot, = tangents\n", " ans = log1pexp(x)\n", " ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot\n", " return ans, ans_dot" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "rhiMHulfKBIF", "outputId": "883bc4d2-3a1b-48d3-b205-c500f77d229c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "print(grad(log1pexp)(100.))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "9cLDuAo6KGUu", "outputId": "59984494-6124-4540-84fd-608ad4fc6bc6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.0485873\n", "0.95257413\n", "[0.5 0.7310586 0.8807971]\n" ] } ], "source": [ "print(jit(log1pexp)(3.))\n", "print(jit(grad(log1pexp))(3.))\n", "print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))" ] }, { "cell_type": "markdown", "metadata": { "id": "9sVUGbGkUOqO" }, "source": [ "Here's a `defjvps` convenience wrapper to express the same thing:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "xfQTp8F7USEM" }, "outputs": [], "source": [ "@custom_jvp\n", "def log1pexp(x):\n", " return jnp.log(1. 
+ jnp.exp(x))\n", "\n", "log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "dtdh-PLaUsvw", "outputId": "aa36aec6-15af-4397-fc55-8b9fb7e607d8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n", "3.0485873\n", "0.95257413\n", "[0.5 0.7310586 0.8807971]\n" ] } ], "source": [ "print(grad(log1pexp)(100.))\n", "print(jit(log1pexp)(3.))\n", "print(jit(grad(log1pexp))(3.))\n", "print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))" ] }, { "cell_type": "markdown", "metadata": { "id": "V9tHAfrSF1N-" }, "source": [ "### Enforcing a differentiation convention\n", "\n", "A related application is to enforce a differentiation convention, perhaps at a boundary." ] }, { "cell_type": "markdown", "metadata": { "id": "l_6tdb-QGK-H" }, "source": [ "Consider the function $f : \\mathbb{R}_+ \\to \\mathbb{R}_+$ with $f(x) = \\frac{x}{1 + \\sqrt{x}}$, where we take $\\mathbb{R}_+ = [0, \\infty)$. We might implement $f$ as a program like this:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "AfF5P7x_GaSe" }, "outputs": [], "source": [ "def f(x):\n", " return x / (1 + jnp.sqrt(x))" ] }, { "cell_type": "markdown", "metadata": { "id": "BVcEkF3ZGgv1" }, "source": [ "As a mathematical function on $\\mathbb{R}$ (the full real line), $f$ is not differentiable at zero (because the limit defining the derivative doesn't exist from the left). 
Correspondingly, autodiff produces a `nan` value:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "piI0u5MiHhQh", "outputId": "c045308f-2f3b-4c22-ebb2-b9ee582b4d25" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nan\n" ] } ], "source": [ "print(grad(f)(0.))" ] }, { "cell_type": "markdown", "metadata": { "id": "IP0H2b7ZHkzD" }, "source": [ "But mathematically if we think of $f$ as a function on $\\mathbb{R}_+$ then it is differentiable at 0 [Rudin's Principles of Mathematical Analysis Definition 5.1, or Tao's Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function `grad(f)` to return at `0.0`, namely `1.0`. By default, JAX's machinery for differentiation assumes all functions are defined over $\\mathbb{R}$ and thus doesn't produce `1.0` here.\n", "\n", "We can use a custom JVP rule! 
In particular, we can define the JVP rule in terms of the derivative function $x \\mapsto \\frac{\\sqrt{x} + 2}{2(\\sqrt{x} + 1)^2}$ on $\\mathbb{R}_+$," ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "ksHmCkcSKQJr" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x):\n", " return x / (1 + jnp.sqrt(x))\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", " x, = primals\n", " x_dot, = tangents\n", " ans = f(x)\n", " ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot\n", " return ans, ans_dot" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "Gsh9ZvMTKi1O", "outputId": "a3076175-6542-4210-ce4a-d0d82e0051c6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "print(grad(f)(0.))" ] }, { "cell_type": "markdown", "metadata": { "id": "Usbp_gxaVVea" }, "source": [ "Here's the convenience wrapper version:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "qXnrxIfaVYCs" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x):\n", " return x / (1 + jnp.sqrt(x))\n", "\n", "f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "uUU5qRmEViK1", "outputId": "ea7dc2c4-a100-48f4-a74a-859070daf994" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "print(grad(f)(0.))" ] }, { "cell_type": "markdown", "metadata": { "id": "7J2A85wbSAmF" }, "source": [ "### Gradient clipping\n", "\n", "While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. 
One canonical example is reverse-mode gradient clipping.\n", "\n", "For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "8jfjSanIW_tJ" }, "outputs": [], "source": [ "from functools import partial\n", "from jax import custom_vjp\n", "\n", "@custom_vjp\n", "def clip_gradient(lo, hi, x):\n", " return x # identity function\n", "\n", "def clip_gradient_fwd(lo, hi, x):\n", " return x, (lo, hi) # save bounds as residuals\n", "\n", "def clip_gradient_bwd(res, g):\n", " lo, hi = res\n", " return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi\n", "\n", "clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "4OLU_vf8Xw2J", "outputId": "5a51ff2c-79c2-41ba-eead-53679b4eddbc" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 24, "metadata": { "tags": [] }, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOydd3ic1ZX/P3fUe2+WZEuWZau6Wy5gDC7CFdPBpJiQhDSS3SSbhGwK2STkR7LZJWWz2QCBEDoYg23cMcY2Bhe5qtqS5aYuS1bv0v39cUcgjGRpNOWdV/N+nmcezbz1i5l3zr3nnHuOkFJiYGBgYOC6mLQWYGBgYGCgLYYhMDAwMHBxDENgYGBg4OIYhsDAwMDAxTEMgYGBgYGL4661gNEQHh4uExIStJZhYGBgoCuOHTt2RUoZce12XRqChIQEcnJytJZhYGBgoCuEEBcH2264hgwMDAxcHMMQGBgYGLg4hiEwMDAwcHEMQ2BgYGDg4hiGwMDAwMDFsYkhEEI8K4SoEULkDbFfCCH+JIQoEUKcFkLMHLBvvRCi2Pxabws9BgYGBgYjx1Yzgn8Ay6+zfwWQbH49DPwVQAgRCjwGzAWygMeEECE20mRgYGBgMAJsso5ASrlfCJFwnUPWAv+Uqub1ISFEsBAiBrgZ2C2lrAcQQuxGGZRXbKHrM5x6FRrLwD8Kxs2AyDQwObd3rLu3j/yKJvIrGmlo60YIGBfkw8zxIYwP89VanoGz01gO5Tlw9SL0dIJvKERMgbgscPfUWt11kVJyrraV3PIGqho76entIyrQm9SYQNLHBWIyCa0ljhkctaAsFrg84HOZedtQ2z+DEOJh1GyC8ePHj05F3kYo3vnJ54BxMPOLMPdr6gFxIqqbOnh6fylvnSinrrVr0GOSI/354vwJ3DM7Hm8PNwcrNHBa+noh7004+gxcPjz4MZ7+kHEXzH8EIiY7Vt8wtHf18uKhi7x69BLnalsHPSbc34v75sSxfkECkQHeDlY49hC2akxjnhG8I6XMGGTfO8ATUsoPzJ/3AD9CzQi8pZS/Nm//GdAupfz99e41e/ZsOeqVxd3t0FQBl4+oh6XkXfAOhFt+AnO+qvkMobu3j//de46/7iuhp1eyLC2KlZkxTI8PJjLQi94+yaX6Ng6dq2PjiXJOlzUSF+LDL9akszQtSlPtBk5AWQ5s/jbUFED4FJi+DhJvgrBk8PCBlhqoPAlntkHuBujtgqyvwS3/rp4Djdl6upL/2JJPTXMncxJCWDs9lqzEUOJDfDGZoKqxg+OXrrItt4p3C6vx9XDju8sm8+CCBNzdnHt27wwIIY5JKWd/ZruDDMHfgPellK+YP59BGYGbgZullF8b7LihsMoQXEtNIez4MZTuhaQlcOdT4Bdum2tbSEVDO9948RinyhpZNTWGH92acl33j5SSj87V8Yst+ZytbmH9/An8+6pUvNyN2YHLISXs/z28/xs1083+FaTdfv2BTesV2Ps45DwHoRPh3n9C9GceX4fQ0d3Lv2/MZeOJcqbGBfHTVWlkJV5/ll5a28Kv3ilg75la5k0M5U/rZhizg2HQ2hCsAh4BVqICw3+SUmaZg8XHgP4souPArP6YwVDY1BCAeohy/g47fwKBsfCFjRCSYLvrj4C88kYe+sdR2rt6+e3dU1mZGTPic7t6+vjdjiKe+eA8cxNDeWb9bAK8Peyo1sCp6OlUs4DTr0HmPbDqv8A7aOTnXzgIGx6Czia47wWYtNR+WgehrqWTr/4zh+OXGviXJcl8e/GkEY/upZRsPF7OT97OJcTXkxe+PJdJkf52Vqxf7GoIhBCvoEb34UA1KhPIA0BK+X9CCAH8DyoQ3AZ8SUqZYz73IeDfzZd6XEr53HD3s7kh6OfSYXj5XnDzhAffUUE1B3D80lW+8Mxhgn09efbBOUyJDhjVdd4+Uc6/vXGK1JhA/vlQFiF+zh0MNLABPZ3w6uegZDcs/hk
s/D6IUQRRm6vhpbugpgjufhbSbrO91kG40tLJfX/7iLKr7fzhvumssGAANJD8ikbWP3sUKSXPP5RFRqwFhtCFsPuMwJHYzRAA1J6Bf6wGNw94aCcEx9vnPmbyKxpZ99QhQvw8ef1r84kKtG5q+15RNV9/8Tjp4wJ5+Svz8PE03ERjlt4e2PAgFG6BNX+EWQ9ad732BjUQKj8Gn3sDkhbbQuWQNLR1se7pw5y/0sLzX8pi7sQwq653/korn3/mMO3dvWz8xgISwv1spHTsMJQhMKIr1xIxRbmGOlvghTug/ardblXZ2M6Dzx3F38udl74y12ojALA4JYo/3T+DU5cbeOTl4/T09tlAqYFTsuNRZQSW/9Z6IwDgE6wMQPgUeO0LUHnK+msOQXdvH19/8Rjnalp4+ouzrTYCAInhfrzw5SyklKx/7ghXWjptoNQ1MAzBYERnwgOvwtULsPFr0Gf7H9OO7l6+/sIx2jp7+MdDWcSF2G5NwPKMaP7jtnT2FNXwX7vP2uy6Bk7EsX/A0adhwbdh3tdtd13vIPj8BvAOhlcegLbrhutGzS+3FHCotJ7f3p3JwuTP9EkZNRMj/Pn7g3Ooburgmy8ZA6GRYhiCoZiwAJb/P7XuYP/vbH75xzblc6qskSfvm87kqNHFBK7HF+YnsC4rnr++f47dBdU2v76BhpQfh63/prLclv6H7a8fOA7ufxFaa2DjV20+ENpwrIwXDl3kazdN5I4ZcTa9NsDM8SH85o5Mjpyv5/e7jIHQSDAMwfWY8xWYtg72/RYuHbLZZXfkVfJazmW+eXMS2enRNrvutTy2Jp2M2EC+//pJKhvb7XYfAwfS1ap+nP0j4e6/g8lOMaBxM2DFb9U6m4NP2uyyl+raeGxTHnMTQ/nh8hSbXfda7pwZxwNzx/N/+86x90yN3e4zVjAMwfUQAlb+JwTFw1tfV3EDK6lp6uDHG3PJjA3iu8vsu6LT28ON/1k3k+5eyaNv5qLHxACDa9j1U6g7B3f8H/jYuSzXrC+ptQh7/x9UF1h9uZ7ePv71tROYTIL/vm86bnYuEfHz1WlMiQrg0TdP09jWbdd76R3DEAyHVwDc/lcVL9j9M6suJaXkR2+epr27lyfvm46HA1ZCJoT78eiKFPadreX1nMvDn2DgvJS8CznPqrhA4k32v58Qn6xJePsb0Gvdj+lTB0o5fqmBX9+eQWywj41EDo23hxu/v2caV1q6+OU71huysYxhCEZCwg0w/1vqIbz44agvsz2vir1navnBrSkOXfTyhXkTmDcxlF+9U0hVY4fD7mtgQ7rb4Z3vqVIRi3/quPv6hcPq/1ZlKQ7+cdSXuVTXxh/fLWZ5ejRrpw9aTswuZMYF8Y1FSbx5vIz3ioxY2VAYhmCk3PLvykW09d9U/raFNHd08x9b8kmLCWT9/Al2EDg0JpPgd3dNo7u3j99sK3TovQ1sxP7fQ8NFWP0kuHs59t5pa9Vr/++h4ZLFp0sp+fnmPNxNgsduS7ODwOvz7SWTmBTpzy82F9DR3evw++sBwxCMFE8/lUVUkw9HnrL49Cd3F1PT3Mnjd2RoUhxrfJgvX1uUxOZTFRwqrXP4/Q2soPaMGo1PWweJC7XRkP24+rvzJxafuiOvivfP1PK97CnEBNnfJXQtXu5u/GJNOpfq23h6f6nD768HDENgCSmrYdIy2PsbtSR/hJTUtPD8Rxe4f854ZozXru/ONxYlERvsw2Ob8o38aj2x66fg4QvLfqWdhuB4Vb6icDOUvj/i0zp7enl8WyEp0QEOnwkP5MbkcJanR/OX90uoaDAy6K7FMASWIIRKqetpt2htwe92FOHtbuL72drWfffxdONnq1M5U93MK0eNwLEuKN0Hxbvgpu+Dv+0WXo2KBd9WxRh3/Fj1PBgBL3x0kbKr7fx0VZrmZaJ/sioVKeGJ7UWa6nBGDENgKWFJajn/sX+oNL5hyLlQz66Car6+KIlwfwf7dgfh1vRoZk8I4c97imn
vMvylTk1fH+z+uYpNZX1NazXg4Q1LHlO9DnLfGPbwxrZu/vxeCTdNjuDGZG1Kuw8kPtSXh25MZPOpCgoqmrSW41QYhmA0LPoRuHnBnl9e9zApJb/ZVkhkgBdfXpjoIHHXRwjBD5enUNPcyT8+vKC1HIPrkb9RZess/qn6EXYG0m6H6KnKPdozeOe8fv73/RKaOrp51I4Lxyzl6zclEejtzu93ndFailNhGILR4B+ppskFb6tKjUOwq6Ca45ca+N6yyfh6Oqor6PBkJYZy85QI/vp+ibHQxlnp7Yb3fgVRmZB5r9ZqPsFkUrOChotw/PkhD6tq7OC5Dy9wx4xY0sZp3/msnyBfD762KIn3imrIuWCfOkp6xDAEo2XBI2pl577/HHS3lJI/7SkmIcyXu2fZvp6Ktfzg1ik0dfTwt/3Du7cMNOD062oR4+KfaN4+9TNMWgITboB9v1MlLwbh//ado7dP8t2lztUPGeBLNyQQ7u/F73acMVbbm3Gyb5iO8AqAed+Cs9sHLdf7XlEN+RVNfPOWkXdbciTp44JYNTWG5z+8YMwKnI3eHjjwe+WCmbxcazWfRQhY8nNVlO7YPz6zu6a5g1eOXOLOGbHEh9quqq6t8PV059uLJ3HkQj0fGanUgI0MgRBiuRDijBCiRAjx6CD7nxRCnDS/zgohGgbs6x2wb7Mt9DiMuQ+DVxDs//SsQErJn94rIS7EhztmOG4VpaU8csskWrt6ef6jC1pLMRhI3ptQX6piUaPpNuYIxs+DhIXw4Z9Vl7QBPL2/lJ4+ySOLJ2kkbnjumxNPuL8Xf33fmBGDDQyBEMIN+AuwAkgD1gkhPrV8UEr5XSnldCnldODPwMYBu9v790kpHdMfz1Z4B6la8IVboDr/480Hiq9w6nID37x5kkPqCY2W1JhAlqRE8uzB87R2Wr5a2sAO9PWqgUVkOkxZqbWa67Pwe9BcCade/XjTlZZOXjx0ibXTxzEhzHk7hHl7uPGVhYkfP6uuji1+pbKAEillqZSyC3gVWHud49cBr9jgvs7B3K+Dpz8c+K+PN/3PeyXEBHlz1yznnQ30881bJtHQ1s0rRywvHWBgBwo3Q10x3PRvzhcbuJaJt6hy1R88+XHZlecOnqejp5dv3eK8s4F+Pjd3PIHe7vzv+yVaS9EcW3zTYoGBq5PKzNs+gxBiApAIvDdgs7cQIkcIcUgIcftQNxFCPGw+Lqe2ttYGsm2EbyjMfgjy34aGS5y83MCRC/V8+cZEvNydv1/wrAkhzJ8YxlP7S+nsMdYVaM5Hf4HQiaq2j7MjhFptfPU8FLxNW1cPLx66xK1p0SRFOK6o4mgJ8PbgwQUJ7Myvpri6WWs5muLoIcf9wAYp5cBfnAnmZsoPAH8QQiQNdqKU8ikp5Wwp5eyICI1XWF7LXPNin8N/4+8fnCfAy5375ti36b0t+eYtSdQ0d7L5ZIXWUlyby0eg7CjM+6b9Gs7YmimrVI/jD//EmzmXaWzv5qs3OceamZHw4A2J+Hi48ZSL1yCyhSEoBwb+6sWZtw3G/VzjFpJSlpv/lgLvAzNsoMmxBMVB+u30HXuefbml3J8VT4C3h9aqRsyNk8KZHOXPcwcvGOl0WvLR/6hewdMf0FrJyDGZVJys8hRH9m9jenwwMzWsp2UpoX6e3Dkzlk2nKqhz4Wb3tjAER4FkIUSiEMIT9WP/mewfIUQKEAJ8NGBbiBDCy/w+HLgB0GcHiXnfwtTVzD2m91m/IEFrNRYhhOBLNyRSUNnEkfPGIhtNuHpBJR3M/pKqdKsnpt5Pt2cQK1rf5isLExHOmuk0BA8uSKCrp8+l42RWGwIpZQ/wCLATKARel1LmCyF+KYQYmAV0P/Cq/PSQMxXIEUKcAvYCT0gpdWkIWiKmcVxO4Rs+u4kL0r6mkKXcPj2WYF8Pnjt4QWsprsmh/wNhgqyHtVZiOZ6+bPPI5la3HJbHXr/shDOSHBXAwuRwXjh0kW4XrcprkxiBlHKblHK
ylDJJSvm4edvPpZSbBxzzCynlo9ec96GUMlNKOc389++20KMFb+Rc5m/dKwjvroSirVrLsRgfTzceyBrProIqLte3aS3HtehoghMvQMZdEDhOazUWk1feyG/rFiKEwP2YPh/hL92QQHVTJ9tyK7WWoglOnp+mD6SUvHDoIldil0DQeDj6tNaSRsUX5k9ACME/P7qgtRTX4vRr0NXiHBVGR8GLhy5y1SOKnsmrVf2hIcpOODM3T44kMdzPZWfEhiGwAYdK6ymtbWXdvIkwaz2c3w9XirWWZTExQT6syIjm1aOXjRLVjkJKyHkOYqZB7Eyt1VhMU0c3m05WcNu0cXje8E3oaFSGTWeYTIL18ydw8nIDJ11wgZlhCGzAS4cvEujtzuqpMTDzi2ByVw+3Dvn8vAk0d/Sw1UWnyA7n8hHV/nT2l523nMR1ePtEOe3dvXxu3niInwtRGYPWH9IDd82Kw8fDjVcOu17Q2DAEVlLb3MnO/CrunhWPt4ebKlGdugZOvgTd+muJNzcxlIkRfi6dQeFQcp4Fr0AVH9AZUkpePHSRqXFBTI0LVoZs1oOqCGPFCa3lWUyAtwe3TRvH5lMVNHe4ViFGwxBYyRvHLtPdK3lg7vhPNs5+CDoa1GpjnSGE4IGs8Ry7eJUzVa692tLutNVD/lsw9T7wcv6VuNeSc/EqZ6tb+NzA737mPeDuo9tZwbq542nv7mWTiy2uNAyBFfT1SV4+fIl5E0OZFDngQU5YCGHJkKPPDIo7Z8bh6WYyZgX25uRL0NupBg465KVDFwnwdmfNtAGZTj7BkHEn5G6ATv0NJKbFBZEaE8jLhy+51OJKwxBYwf7iWsqutvO5uRM+vUMI9XCXHYWqXG3EWUGonyfLM6LZeLyMjm4jaGwX+vpUHGn8fIhKG/54J6O+tYttuVXcNTPus933Zj2osqDy3tREmzWoGXE8BZVN5JY3ai3HYRiGwApePXKZMD9Pbk2P/uzOafeDmyeceMnxwmzAuqzxNHX0sPW0ETS2C5c+hPpz6kdTh7x1opyu3j7WZY3/7M64ORCZplv30NoZsXh7uNaM2DAEo6S+tYs9RdXcMSMWT/dB/hl9Q1U9+dzXh23y7YzMmxjKxHA/Xnahh8GhnHwZPAMgVV8tOPrZcKyMaXFBTIkO+OzO/qBxxQmoOOlwbdYS6O3Bmqnj2HSyghYX6dNhGIJRsvlkOd29kruu1494+uegrQ7O7nCcMBshhOC+OfEcu3iV0toWreWMLTpbVCJBxh3g6XytHIcjv6KRwsqm6/finnovuHvDiRcdJ8yG3J81nrauXraedo2gsWEIRsmG42VkxAaSGhM49EFJi8E/WgUFdcgdM2IxCdh4fKhisgajonAzdLeqgYIO2XCsDE83E7dNu07jJZ8QNSPO26DLGfHM8cFMDPfjTRf57huGYBQUVjaRV97E3TOvMyICcHNXsYLi3dBc7RhxNiQy0JuFyRG8daKcvj7XyaCwOydfhtAktQBLZ3T19LHpZAXL0qMI8h2m1Pq0ddB+FYp3OUacDRFCcNesOI6cr3eJ2luGIRgFbx4rw8NNcNv0EbSinPF5kL1w+tXhj3VC7pwZS3lDO4fO12ktZWxQfx4uHFA9B3S4kvi9ohrqW7uu7xbqJ2kx+EXCKX12pr19RizCRWbEhiGwkO7ePt4+Wc6SlChC/TyHPyE8GeKyVPaQDvOSb02PJsDLnTePjf2HwSGcehUQaqaoQzYcKyMywIuFk8KHP9jNXcUKzu6EVv0NJGKDfZg/MYyNJ8rG/JoCwxBYyL4ztVxpGeGIqJ8Zn4MrZ6D8uP2E2QlvDzdWTY1he14lbV2ukUFhN/r64NTLMPFm1dVOZ9Q2d7L3TA13zIzF3W2EPx3T1kFfN+RvtK84O3HnzDgu1rVx7OJVraXYFcMQWMiGY2WE+3uyaIoFfZPT7wA3L5VKqkPunBlHW1cvO/KqtJaiby4ehIZLug0SbzpZTm+fHD42NpDoDIj
KVHERHbI8IxofD7cxHzS2iSEQQiwXQpwRQpQIIR4dZP+DQohaIcRJ8+srA/atF0IUm1/rbaHHXjS2dbOnqJq102PxGOmICMA7CCZnq5WWvfobVc9JCCE+1Ic3j5dpLUXfnH4NPP0hZZXWSkbF2yfLmRoXRHLUIGsHrsf0dVBxHGrP2EeYHfH3cmdFRjTvnK4Y06vsrTYEQgg34C/ACiANWCeEGGzN/GtSyunm1zPmc0OBx4C5QBbwmBDCaTtfb8+rpLtXcvtIgsTXknkvtNbC+X22F2ZnhBDcOSOOD8/VUdGgv4qqTkFPJxRsVpVpdbh2oKSmhbzyJtaO6rt/Dwg3c3xEf9w5M47mjh7eLdRf5t9IscWMIAsokVKWSim7gFeBtSM891Zgt5SyXkp5FdgNLLeBJruw6WQFieF+ZMReZ+3AUCRng1cQ5L5he2EO4M6ZsUgJW065xgIbm1O8GzobIeNurZWMis2nKhAC1XPDUvwjYdISOP26ipPojPlJYcQEefPWGHYP2cIQxAKXB3wuM2+7lruEEKeFEBuEEPEWnosQ4mEhRI4QIqe2ttYGsi2juqmDQ+fruG3aOMRo0v48vCFtDRRu0WWfgglhfkyLD2azYQhGR94G8A2HiYu0VmIxUkq2nKpg/sQwogK9R3eRzHugqQzKjthWnANwMwlWT41hf3EtDW36Wxw3EhwVLN4CJEgpp6JG/c9begEp5VNSytlSytkRERYEam3EllMVSAm3TbeiuXjmvaoq45ntthPmQNZMjSG/oolzRskJy+hsVv/P0+8At2EWYTkhueWNnL/Sym3TrPjuT1mhSk7osCIpwG3TYunulezMH5sJE7YwBOVA/IDPceZtHyOlrJNSdpo/PgPMGum5zsKWUxWkjwskKcKKBiIJN0JAjG7dQ6unjkMIwz1kMUXboKcDMnXqFjpZgYebYEXGKNxC/XgFwOTlqhGPDhMmMmIDSQjzHbMzYlsYgqNAshAiUQjhCdwPbB54gBBi4DfoNqDQ/H4nkC2ECDEHibPN25yKC1daOVXWyFprZgMAJjfVkrB4t+pOpTOig7zJSgg1z47G9gIbm5L7BgSNVwsLdUZvn2TL6QoWTY4cvqTEcGTcpRImLhywjTgHIoRgzbRxfHSujprmDq3l2ByrDYGUsgd4BPUDXgi8LqXMF0L8UgjRX2P3O0KIfCHEKeA7wIPmc+uBX6GMyVHgl+ZtTkX/KGD1VCsNAShfaV83FGyy/loacNv0cZyrbaWwUn/dpzSh9Qqce0917TLpb9nOkfP1VDd1Wj8IAkhepkpv69Y9NI4+Cdtzx557yCbfTCnlNinlZCllkpTycfO2n0spN5vf/1hKmS6lnCalvEVKWTTg3GellJPMr+dsoceWSCnZdLKcrIRQxgX7WH/BmGmqjaVOH4YVGTG4m8SYnSLbnIK3Va2pzHu0VjIqNp8qx9fTjaWpUdZfzMNHraEo3KzSaXVGclQAKdEBY/K7r78hioMpqGziXG2rdUHigQihgoYXD0JLjW2u6UBC/Ty5MTnccA+NlNw3ISIFotK1VmIxXT19bMutIjstCh9PN9tcNPNu6GhUsyQdsmbaOI5dvErZ1bFVkdQwBMOw+VQF7ibBykwrAmXXkn47yD41MtIha6aOo7yhneOXGrSW4tw0lquWlBl367LS6IHiWhrbu203CAJVZ8knRLcz4jVm9/BYa+FqGILrIKVkW24lCyaFj6zS6EiJTIPwyapLlQ7JTo/C091kZA8NR7+hz7hTWx2jZFtuFYHe7tw4yYbp2m4ekLZWZVJ16W9UPT7Ml2nxwWwZY53LDENwHfIrmrhc386qzEGa01uDEJB2u27dQwHeHiyeEsnW3Ep6jYY1Q1OwSRVcC0vSWonFdPX0sbugimVp0YP35LaGjLtVh7az+l1Pk1feNKZauBqG4Dpsy63EzSRYlmZjQwAqTqBj99CqqTHUNneO+fK8o6apEi4dUqNfHXLw3BWaOnpYaet
BEMCEBaphjU4z5/qzB7fljh33kGEIhqDfLTR/Ypht3UL9RKbq2j10S0okXu6mMfUw2JTCLYDUrSHYnluJv5c7NyaPoAGNpZjcVPG94t26dA9FB3kzc3ww28dQWXbDEAxBUVUzF+raWGGPERHoPnvI38udmyZHsCOvyuhnPBgFmyAiFSIma63EYrp7+9hVUM3S1Ei83G2ULXQtaWuhuw1K3rXP9e3MykxVbuVSnf4M2WAYhmAItudWYhKqVaPdSNN39tDKzGiqmjo4WWZkD32Klhpl4NNv11rJqDhUWkdDW7dtM+WuZcIN4BOqW/dQ/+/C9ryxMSM2DMEQbMurYm5iGOH+Xva7SWQqhE/RrXtoSWoUHm6C7YZ76NMUbkbPbqFtuZX4ebpx02Q7Fnd0c4fU1aqfcbf+SjbEh/oyNS6IbWPEPWQYgkE4W91MSU2LfQJlAxFCjRovHoRm/TW9CPT24MZJ4WzLrTIWlw2kYJOK/0SkaK3EYnp6+9iZX83i1Ci8PezkFuondS10NUPpXvvex04sz4jm1OUGysdAsybDEAzCttxKhL3dQv3o3D20IjOG8oZ28sqbtJbiHLRegQsfqNmADheRHTlfT31rFyszHPDdT7xJtXEt0Ol331yNdSz08jYMwSBsz61izoRQIkfbhMMS+rOHdGoIstOicDcJto0RX6nVFL2jDHuaPuMDW3Mr8fFw4+Ypkfa/mbsnTFkJZ7ZCj/4aviSG+5ESHcCOMfDdNwzBNZTUtHCmutl+2ULXIgSkrIYLB3VZmjrY15P5SWFsz6003EOg4j2hSbqsLdTbpxqvLE6JtF1toeFIW6tqD53f75j72ZiVmTHkXLxKTZP+4hwDMQzBNfRbd6uacFhK6hpVofLsDsfd04asyIjhQl0bRVUuXpq6rV79oOnULXT0Qj1XWrocNwgCmHiLKk1dqM/soZWZ0UiJ7juXGYbgGrblVjFrQgjRQQ5wC/UzbgYExpkXIemP7PQoTAIje6hoqzLoOs0W2p5biZe7iVsc4Rbqx8MbJt8Khe/osnPZpMgAJkX6s03nPQpsYgiEEMuFEGeEECVCiEcH2f89IUSBuXn9HiHEhAH7eoUQJ80vTR3lF+taKbq/0P0AACAASURBVKhsYoUjAmUDEULVaT/3HnS1OvbeNiDc34usxNAxk0o3ago3Q/AE1XNCZ/T1SXbkV3HzlAj8vNwde/O0tdBer7LndMjKjGgOn6+jrkV/PRb6sdoQCCHcgL8AK4A0YJ0QIu2aw04As83N6zcAvxuwr11KOd38ug0N2ZWvUjgdki10LalrVF9bHa+0LKlpobjaRd1Dnc1Q+r76/6hDt9Dp8kaqmzq1+e5PWgoevrpdXLY8I4Y+CbsK9JcC3o8tZgRZQImUslRK2QW8Cnxqbiyl3Cul7F+LfQjVpN7p2FVQRVpMIPGhvo6/+fj5aqWlXt1D5sJ8en4YrKLkXejtUjM7HbIrvwo3k2BxigPdQv14+ipjULQV+vocf38rSY0JYEKYr67jBLYwBLHA5QGfy8zbhuLLwMD6s95CiBwhxCEhxJA5d0KIh83H5dTW1lqneBCutHSSc/Eq2ek2aMk3GtzcVSrd2V26TKWLDvJmWnwwu3T8MFhF0VbwDYP4uVorGRW7CqqZmxhKsK8dCiyOhJTV0FIFFce1ub8VCCHIToviw5I6mju6tZYzKhwaLBZCfB6YDfzngM0TpJSzgQeAPwghBi3eLqV8Sko5W0o5OyLC9kvf9xRWI+UnI1tNSF0DnY1wQZ+pdNlpUZwqa6SyUf8rLS2ip0sZ8CkrVGVNnVFa20JJTQvZaRoNggAmZ4PJXa3D0CHZ6dF09fax76ztB6mOwBaGoByIH/A5zrztUwghlgI/AW6TUn4cVZFSlpv/lgLvAzNsoMliduVXExvsQ2pMgBa3V0y8GTz9dese6vcvv+tq7qGLHygDnrJaayWjYrf5/9cyLeID/fiEQMKNKntIh8wcH0K4v+fHcUa
9YQtDcBRIFkIkCiE8gfuBT2X/CCFmAH9DGYGaAdtDhBBe5vfhwA1AgQ00WURrZw8HSq6QnR6F0DLQ5+Ft9pVug75e7XSMkkmR/kyM8GOnTh+GUVO0VQU7J96stZJRsaugmozYQGKDfbQVkrIa6oqh9qy2OkaBm0mwNDWKvUU1dPXoL85htSGQUvYAjwA7gULgdSllvhDil0KI/iyg/wT8gTeuSRNNBXKEEKeAvcATUkqHG4L9Z2vp6unT1i3UT+oaaK2BsqNaKxkV2WnRHCqto7FNn75Si5FSGe5JS8BD4x/SUVDT3MHxS1dZluoE3/0pK9Rf3bqHomju7OGj0jqtpViMTWIEUsptUsrJUsokKeXj5m0/l1JuNr9fKqWMujZNVEr5oZQyU0o5zfz377bQYym7C6oJ9vVgTkKIFrf/NMnZ4OapW/dQdnoUPX2SvWf012xnVFScgOYK3bqF9hTWqNiYVkkSAwmKU4sri7ZqrWRULEgKx9fTTZcJEy6/sri7t489RTUsSYnC3c0J/jm8AyFxkTIEOqzdMz0umIgAL3YV6O9hGBVFW0G4KQOuQ3blVxEf6kNKtIaxsYGkrILyHNXzWWd4e7hx85QIdhdU665rnxP88mnL0fP1NLZ3O8eIqJ/UNdBwEarztFZiMSaTYFlaFO+fqaWjW39xDosp2qqasfuGaq3EYlo6ezhYUkd2WrS2sbGBpKxRf89s01bHKMlOi6amuZNTOuva5/KGYFdBNd4eJm5KtmM3JkuZshKESb/uobQo2rp6+fDcFa2l2Je6c1BbqFu30P6ztXT19mmbNnotEVNU9VadxglumRKJu0nobmGlSxsCKSW78qtYmBzhuLK7I8E/AuLn6TaVbn5SGP5e7rpNpRsx/b7slJXa6hglu/KrCPXzZNYEJ4iN9dNfd+v8fmjX16gaIMjXg3kTw3QXJ3BpQ5Bf0URFY4dzjYj6SV0NNflQf15rJRbj5e7GLSmR7C6opldnvlKLKNoK0VMheLzWSizmk9hYpHPExgaSshr6enRbdys7PYpzta2U1LRoLWXEONk3wLHsyq/CJFQTdqejv2aNTjMostOiqGvt4vilq1pLsQ8tNXD5sG7dQodL62nu6GGZMw6C4maDX6Ru3UNLzb8nu3XkHnJtQ1BQzeyEUEL9NKqvcj1CEiAqU7eG4OYpEXi4Cd1NkUfMme2A1K9bqKAKbw8TC50pNtaPyU2tKSjeDT36K+08LtiHqXFBusqcc1lDcMncUcsp3UL9pKyCSx9Bi/7qlwR4e7AgKZxdBdVjs4Vl0VblEorK0FqJxUgp2V1QzU3OFhsbSOoa6GrRbQvL7LQoTlxqoFonLSxd1hD0W2unWE08FCmrAAlntw97qDOSnR7Fxbo2zlbrx1c6IjpbVO+BlNW67D2QV95EZWMH2VrWFhqOxJt0XXer/99WL+4hFzYE1aREBzA+TIPeAyMlOhOCxuvWPbTM7Csdc+6hc3ugt1O/vQcKzLExLXoPjBR3L0heptYT6LDuVnKkPwlhvrpJI3VJQ1DX0knOhXrnHhGBGm2mroZze1UHLJ0RGejNjPHBunkYRkzRVtVEKH6e1kpGxa78arISQwlxxtjYQFJWQ2stlOVorcRihBBkp0fz0bkrNOmgR4FLGoI9RTX0SZw7PtBPyio1+izZo7WSUZGdFk1ueSPlDWOkR0FvN5zdoYKZbg7u7WsDLlxp5Ux1s3O7RPtJXgYmDyjSqXsoLYruXsn7Z5w/xueShqC/90D6uECtpQxP/Dw1+tSpe6i/dMeY6VFw8SB0NOrWLfRx7wE9DIK8gyBxofru6zDhYIa5R4EeWli6nCFo7+rlg5JalqVp3HtgpHzcwnKnGo3qjKQIf5Ii/HSVSnddiraCuw9MvEVrJaNiV0EVqVr15R4NKaugvhRqz2itxGL6exTsO1NLZ49zxzlczhDsL66lo9vJ6qsMR8oqcwvLA1orGRXZ6dEcKq3Xf48CKZUhSFqsGq7
rjCstnRy7eFVf3/0p5nUaZ/Q7I27p7OGjc87do8DlDMGu/GqCfDyYk6ijapFJt6gOWHp1D6VF0dsnee+Mzt1DlSehqVy3bqH3Cs2xMWeqtDscgeMgdpZuv/v9PQqcPY3UJoZACLFcCHFGCFEihHh0kP1eQojXzPsPCyESBuz7sXn7GSHErbbQMxQ9vX3sKapmSUokHs5WX+V6ePioDlhF26BPf23wpsUFExng5fQPw7AUbVVVYScv11rJqNhVUEVssA9pMTqIjQ1kykooPwZNFVorsRhvDzcWTXb+HgVW/xoKIdyAvwArgDRgnRAi7ZrDvgxclVJOAp4Efms+Nw3V4zgdWA78r/l6duHohas0tDlZ74GRkrJadcKqOKG1EosxmQRLx0KPgqJtMH4B+IVprcRiWjt72F/sBH25R0N/PSe99ihIj3L6HgW2GBZnASVSylIpZRfwKrD2mmPWAs+b328Algj1bVwLvCql7JRSngdKzNezC7sKqvByN3HTZCesrzIcydmqE5ZOC3HpvkdBfamqBqtTt9CBYifqy20pH/co0Kd7aPGUKNycvEeBLQxBLHB5wOcy87ZBjzE3u28EwkZ4LgBCiIeFEDlCiJza2tHl5Xb19LE0NQpfT/3lf+MbCgk36vZh0H2PgiLzaFSvRebynagvt6V83KPggErd1RlBvh7MTQx1ateobhzlUsqnpJSzpZSzIyJGN6J//I5M/ueBGTZW5kBSVsOVM3ClWGslFuPlrvq5vluo0x4FRVtVgbmQBK2VWEyPuffAYmfsPTBSUlZDX7eqSKpDstOiKKlp4Vytc9bdssW3ohyIH/A5zrxt0GOEEO5AEFA3wnNtiu78owPpH43q1T2UHs2Vli5OXtZZj4LWK3D5kG7dQkcumPty69Et1E/cbPCL0O2MeJmTF6GzhSE4CiQLIRKFEJ6o4O/ma47ZDKw3v78beE+q2sSbgfvNWUWJQDJwxAaaxiZBcTBuhm4fhk96FDjnwzAkZ3eA7Pskp11n7MqvNsfGwrWWMnp03qMgNtiHjNhApy3AaLUhMPv8HwF2AoXA61LKfCHEL4UQt5kP+zsQJoQoAb4HPGo+Nx94HSgAdgDfklLqOK3EAaSsgrKj0FSptRKLCfRW/Vx35lfpq0dB0VYIjIOYaVorsZj+3gMLk8P1GRsbSMpq6GpWsQIdkp0WzYnLDdQ0O1+PAps4DKWU26SUk6WUSVLKx83bfi6l3Gx+3yGlvEdKOUlKmSWlLB1w7uPm86ZIKfVZeN+R6D6VLpoLdW366efa1aaqv6as1GXvgYLKJsob2vXtFuoncRF4+OnWNbosLQopYU9hjdZSPoNOI0cuTESKrlPpPu5R4KS+0s9w7j3oaddtfGBXfjVCwOJUJ+49MFI8vCF5qWoTqsOFlSnRAcSH+jile8gwBHrj41S6/bpMpYsO8mZavI56FJzZpqpgTrhBayWjYndBNbMnhBDu76W1FNuQshpaqqDiuNZKLEYIQXZaNAdL6mjp7NFazqcwDIEeGQOpdKcuN1DV6Hy+0k/R26NGn8m3gpuH1mos5nJ9GwWVTWPDLdRP8jLdL6zs6u1jn5P1KDAMgR6JmwN+kbp9GG41l/jYXejks4LLh6G9XreLyHTVe2Ck+IToemHlrAkhhPh6sNvJyrIbhkCPmEzqx0mnqXRJEf4khvs5pa/0UxRtBTdPmLRUayWjYndBNZOj/EkI99Naim1JWQ1XzkLtWa2VWIy7m4klqVHsKaqhu9d54hyGIdArKauhqwVK92mtxGKUrzSKQ6V1ztvPVUpVAz9xEXgFaK3GYq62dnHkQv3Ycgv1k6LzHgVpUTR39HC4tF5rKR9jGAK9kngTeAbo1j2Une7k/VxrCuDqBd1mC71XVENvnxxbbqF+guIgZvon9Z90xsLkCLw9TE7Vtc8wBHrF3UsFzs5sgz79rcGbHq8yWZzWPVS0DRC6XU28u6Ca6EBvMmODtJZiH1JWq4WVzU76/bk
OPp5u3JSsehQ4y8JKwxDomZRV0FqrHgid4WYSLEuL5H1n7eda9I6qbxOgvxF1R3cv+86qvtwmk/4WwY2IlJWAVFldOmRZWhSVjR3klTdpLQUwDIG+SV4GJg/duoeWpTlpP9fGMtWWUqduoQ+Kr9De3Ts23UL9RKapSrA6zR5akhqFSeA07iHDEOgZ7yCYuAgK31HBTZ3htP1c+0eZU/RpCHYXVBPg5c68ifrrpDZihFDuofP7oLNZazUWE+rnyZyEUKcpwGgYAr2TsgqunoeaQq2VWIy3h+pR4HT9XIu2QlgyREzWWonF9PZJ3i2s5uaUSDzdx/jjnbIKerug5F2tlYyK7PRozlQ3c7GuVWsphiHQPVNWAkK3U+TstGjn6ufa3gAXDuh2EdnxS1epa+0ieyy7hfqJnwu+YTr+7psXVjrBjNgwBHonIFqtNNZpnOCWKZG4O1M/15J3oa/nkyqvOmN3QTUeboKbp+iwL7elmNxg8go4uwt6urRWYzHxob6kRAc4hXvIMARjgZRVKrjZcHn4Y52MIF/Vo8Bp0kiL3lHlO2Jna63EYqSU7MyvYn5SOAHe+quNNCpSVkFnI1z8QGsloyI7PZqci/VcadG2QoBhCMYCOu9RsCwtinO1rdr3c+3phOJ3VScsk/4ejeKaFi7WtbmGW6ifpFvAw1fX7qE+Ce9p3KPAqm+7ECJUCLFbCFFs/hsyyDHThRAfCSHyhRCnhRD3Ddj3DyHEeSHESfNrujV6XJbwSapPgU7dQ8ucxVd6fr/qgKXTtNExWWRuODx8IGmxWgCow8y59HGBxAb7aO4atXbY8yiwR0qZDOwxf76WNuCLUsp0YDnwByFE8ID9P5BSTje/Tlqpx3VJWQUXDkKb89QvGSnjgn3IjA3S3j1UsEmV7Zh4s7Y6RsmOvCqmxQcTFeittRTHkrIamiug4oTWSixGCMGytCgOFNfS1qVdjwJrDcFa4Hnz++eB2689QEp5VkpZbH5fAdQALhDJcjApq0D2wtmdWisZFdlpUaqfa5NGPQp6e5RrbfKtqnyHzii72kZueSMrMsZgkbnhmHyruUeBft1DnT197D97RTMN1hqCKCllfxf1KuC6c1IhRBbgCZwbsPlxs8voSSHEkE+gEOJhIUSOECKnttZJC5VpScwMCBinW/dQdno0UsK7WvlKL30IbXWQdps297eSHXlqNuWShsA3FCYs0K0hmJMYSpCPh6arjIc1BEKId4UQeYO81g48TqrqSUM66YQQMcALwJeklP2FuH8MpABzgFDgR0OdL6V8Sko5W0o5OyLCmFB8BpNJzQpK9qiG6zpjcpQ/E8J8tWvYUbAZ3H1023tgR14VqTGBTAgbY70HRkrKKqgthLpzwx/rZHi4mViSEsl7RTX0aNSjYFhDIKVcKqXMGOS1Cag2/8D3/9APOpwTQgQCW4GfSCkPDbh2pVR0As8BWbb4j3JZUlapRuule7VWYjFCCJalRmnTz7WvT82kJi0BT/39kNY0dXDs0lXXnA30018lVqezguz0KBraujl64aom97fWNbQZWG9+vx7YdO0BQghP4C3gn1LKDdfs6zciAhVfyLNSj2uTcKOqP6TbhyFam36u5TnQXAlpa4c/1gnZmV+FlC7qFuonZAJEZ+r2u78wOQJPdxM7NUqYsNYQPAEsE0IUA0vNnxFCzBZCPGM+5l7gJuDBQdJEXxJC5AK5QDjwayv1uDZuHjB5uSqa1qtdBsJomTUhhDA/T3Y4+mEo2KSquE6+1bH3tRHb86qYGOHHpEh/raVoS8pq1We6Rduc/NHg5+XOTckR7Mir0qTullWGQEpZJ6VcIqVMNruQ6s3bc6SUXzG/f1FK6TEgRfTjNFEp5WIpZabZ1fR5KaXGK4rGACmrVMP1Sx9prcRi3EyC7PRo3iuspqPbQT0KpITCLSpl1Ft/TVzqW7s4fL6eFRnRqIm1C5O6BpBQuFlrJaNi1dRoqpo6OHHZ8XW39Ld80uD6JC0
BNy/dTpFXZcbQ2qUaqziEqtPQcFG32UK7C6ro7ZOsyIjRWor2RKapqrEFn/FQ64IlqVF4upnYlls5/ME2xjAEYw0vf/NKy626XGk5b2IoIb4ejnsYCjarHHSd9h7YnldFXIgP6eMCtZaiPUJA+u1w4QNo0V+KeaC3BwuTw9meW+lw95BhCMYiKaug8ZIa7eoMdzcTt6ZHs6ewxjHuocItkHAD+OmviUtjezcHS64YbqGBpK0F2QdFW7RWMipWZsZQ0djh8LLshiEYi0xZAcKkW/fQyswYWjp72G9v91DtGbhyBlL16RbaW1RDd69kueEW+oSoDAhNgvy3tVYyKpamReHhJhzuHjIMwVjELxzGz9etIZifFEawI9xDBeagok57D2zPqyQq0IsZ8cHDH+wqfOweOgCt2pVsGC1BPh7cOCmcbblVSAe6dg1DMFZJWQXVeVBfqrUSi/FwM3FrWjTv2ts9VLgJ4rIgUH8j6rauHvadrWV5ejQmk+EW+hRpt5vdQ/ost7IyM4byhnZOlzU67J6GIRirpK5Rf3U6RV45VbmHDhTbaVR3pQSqciH9Dvtc3868V1RDR3ef4RYajOhMCEnU7Xd/WVoU7ibHuocMQzBWCR6vWljmbdRayahYkBRGkI8d3UP5GwGzG0GHvHOqkogAL7ISQ7WW4nz0u4fO74fWOq3VWEywryc3TApnW16lw9xDhiEYy2TcBdW5UHtWayUW4+FmIjstincLqunssYN7KO9NFUcJHGf7a9uZ5o5u9p6pYVVmDG6GW2hw0m5XZdl16h5alRnD5fp28sqbHHI/wxCMZdJuB4R59Ks/Vk6Nobmzhw9s7R6qLoDaIsi407bXdRDvFlbT2dPHmmmGW2hIYqZBSAIU6Nc95GYSbHWQe8gwBGOZwBiYcINyD+lwcdkNSeEEervb/mHI36jSa3VaZG7LqUpig32YEf+ZzrAG/QihBkKl+3TZtS/Ez5MFSWFsd5B7yDAEY52MO1SufE2B1kosxtPdRHZ6NLtt6R6SUhnGhIXgH2mbazqQhrYuDhTXsmpqjJEtNBxpa5V7qFC/i8su1rWRX2F/95BhCMY6qWvV6DfvTa2VjIpVU2No7uixXRu/qtNQf063bqGd+VV090pWTzXcQsMybgaEToS8DcMf64QsT4/Gw02w6WS53e9lGIKxjn8EJC7SrXvoxknhhPp52u5hyNsIJnfdriZ+53QlE8J8yYzVX6VUhyMEZN4D5w9Ak+MLuVlLiJ8niyZHsPlUBb12rj1kGAJXIONOuHoeKk5orcRiPNxMrMqM4d3Caus7l0mp4gMTb1Z9bnXGlZZODpZcYfXUGKO20EjJvAeQuk2YuG16LNVNnRw5b984h1WGQAgRKoTYLYQoNv8dNHolhOgd0JRm84DtiUKIw0KIEiHEa+ZuZga2JmW1aryi04dh7fRxdHT3scvahjXlx6HhEqTr0y20Pa+KPglrpukv5VUzwpMhZjrkvqG1klGxNDUSX083Np+yr3vI2hnBo8AeKWUysMf8eTDaBzSlGTgn/y3wpJRyEnAV+LKVegwGwzdUlabOf1v159UZsyaEEBfiw6aTFdZdKPcNcPNU5Td0yJZTFUyK9GdKVIDWUvRF5j1qNnylRGslFuPr6U52WhTbcqvss57GjLWGYC3wvPn986i+wyPC3Kd4MdAfybHofAMLybgLGi+rVn46QwjBbdPG8UHJFa60dI7uIr09Kmg4eTn46K9IW3lDO0fO13PbtHGGW8hSMu4EhG6Dxmunx9LY3m27hIlBsNYQREkp+6MwVUDUEMd5CyFyhBCHhBD9P/ZhQIOUst/xWwbEDnUjIcTD5mvk1Nbqr+mE5qSsAg9fOP2q1kpGxe0zYuntk2w9PcqgX+leaK2FqffZVpiDePuEcg3cMWPIR8RgKALHQcKNakaox4SJZJUw8bYds4eGNQRCiHeFEHmDvD61GkeqVQ9D/StPkFLOBh4A/iCESLJUqJTyKSnlbCn
l7IiICEtPN/DyV4Xo8t6C7g6t1VjM5KgAUqIDRv8wnHoVfEIgOdu2whyAlJK3TpQzJyGE+FBfreXok8x7oK4EKk9qrcRiPk6YKLBBwsQQDGsIzE3pMwZ5bQKqhRAxAOa/NUNco9z8txR4H5gB1AHBQgh382FxgP0TZl2ZafdDZyOc3aG1klFx+4xYTlxq4FJdm2Undjar3gzpd4K7/vIR8sqbKKlp4Y4ZcVpL0S9pt6n4UK5e3UPj6OyxQcLEEFjrGtoMrDe/Xw98pmu0ECJECOFlfh8O3AAUmGcQe4G7r3e+gQ1JXAQBMWp0rEP6s2UszqAo2Aw97bp1C715vAxP86jQYJT0zwbz3oQ+B7RAtTEzx4cQG+zD29YmTAyBtYbgCWCZEKIYWGr+jBBithDiGfMxqUCOEOIU6of/CSllf72DHwHfE0KUoGIGf7dSj8H1MLlB5t1QsluX3Ztig33ISghl44lyy+qvnH5N1aePz7KfODvR3dvHllMVLEmNJMjXQ2s5+mbqvdBcqeJFOsNkEqydPo4Pimupaba9a9cqQyClrJNSLpFSJptdSPXm7TlSyq+Y338opcyUUk4z//37gPNLpZRZUspJUsp7pJSjTAkxGDHT1kFfj25LTtw1K5bS2laOXxphc+/GclWXfup9aqWpzjhQXEtdaxd3zjTcQlYzebmaGZx4SWslo+LOmXHcPSuO7l7bB7yNlcWuRlQ6RGXq1j20auo4fDzc2HDs8shOyH0DkGo0qEM2Hi8nxNeDRZONBAmrcfeCzHtVvKj9qtZqLGZSpD+/u3sascE+Nr+2YQhckWn3Q8VxXTas8fdyZ2VmDFtOVdLWNUwGhZTK4MXNgTCLE9U0p6mjm90F1ayZNg5Pd+NRtQkzPge9nboNGtsL49vlimTeoyqSnnpZayWj4t7ZcbR09rAjb5gMirIcqC2EGZ93jDAb886pSjp7+oy1A7YkZpqaEZ/U53ffXhiGwBUJiFIZFCdfht5urdVYTFZiKBPCfHk9Zxj30Il/goefWlWtQ147eokpUQFMj9ffSminZvoDakZcU6i1EqfBMASuysz10FINZ3dqrcRihBDcMyuOQ6X1Q68p6GyG3Dch/Q7w0l9tnoKKJk6VNXJ/VrxRUsLWTL1XlSI/8aLWSpwGwxC4KsnZ4B8Nx58f/lgn5M6ZcQjB0EHj/LeguxVmftGxwmzEa0cv4eluMtxC9sAvXGUQnX5NlzNie2AYAlfFzV0FzkrehcYyrdVYzLhgHxYmR7DhWNngTTuOvwDhU3S5dqCju5e3TpSzIiOaYF/9rYTWBTM+r2pPFe/SWolTYBgCV2bGF0D26Tav+t7ZcVQ0drC/+JoihDWFUHZEzQZ06FbZnldJU0cP982J11rK2GXSMrXKPudZrZU4BYYhcGVCE1W3rhMv6HLZfXZaNOH+nrx06OKndxx/QTXimXa/NsKs5NUjl0kI82X+xDCtpYxd3NxVnKxkD9SXaq1GcwxD4OrM/KLqU3BOf8vuPd1N3D9nPHuKarhcbw4ad7ertNiUlcoXrDNKa1s4fL6ee+cYQWK7M2u9SqPOeU5rJZpjGAJXJ2U1+Ibpdoq8bu54BPDKkUtqQ96batXonK9qqmu0vHDoIu4mwd2zjJISdidwnBownHhRl6XZbYlhCFwddy+Y9SCc3Q5XL2itxmJig31YkhrFa0cv09ndA4f/BhGpqhGJzmjt7GFDThkrM2OIDPDWWo5rMPvL0F4PBa5d+NgwBAbqYUDA0WeGPdQZ+cK8CdS1dnH4wE6oOg1ZX9VlkHjj8TKaO3tYvyBBaymuQ+IiCJuk2+++rTAMgQEExaruZcdfgC4Lm744ATdOCichzBdx5CnwCtRl3wEpJc9/dJHM2CBmjjdWEjsMkwlmP6SyzCpPa61GMwxDYKCY+zXoaIDc17VWYjEmk+CrM/yY236AuuS7VVtOnXGwpI6SmhbWL0gwgsSOZvoDqp/34b9prUQzDENgoBg/H6I
z1cOgwwbfd8ndeIpe/ta2WGspo+IfH14g1M+T1VONLmQOxycEpn9OrTRudXECXgAAEYJJREFUtk8rSGfHKkMghAgVQuwWQhSb/4YMcswtQoiTA14dQojbzfv+IYQ4P2DfdGv0GFiBEDD361BToBq56InudrxP/J3iwAU8W+RGRUO71oos4sKVVvYUVbMuKx5vDzet5bgm876hGjYdeUprJZpg7YzgUWCPlDIZ2GP+/CmklHullNOllNOBxUAbMHBd9w/690spT1qpx8AaMu4Gv0g4+AetlVjGyZegrY6gZd9HAs8dPK+1Iot46kApHm4m1s9P0FqK6xKWBKmr4ejfobNFazUOx1pDsBbor1r2PHD7MMffDWyXUuovIukKeHjD/G/Cufeg4oTWakZGXy98+GeInUVkxhJWZcbwypHLNHXoo5hYTXMHG46VcdfMOCIDjZRRTVnwHRUnO6nPkivWYK0hiJJSVprfVwFRwxx/P/DKNdseF0KcFkI8KYTwGupEIcTDQogcIURObW3tUIcZWMvsh1TmzQc6mRUUblbrH274FxCCh2+aSEtnD68cvqS1shHx3MELdPf28fBNE7WWYhCfBXFZ8NFfdFlyxRqGNQRCiHeFEHmDvNYOPE5KKYEho4xCiBggExhYAP/HQAowBwgFfjTU+VLKp6SUs6WUsyMijP6tdsM7COZ8RS2wuVKitZrrIyUc/BOETlQrpIGM2CAWJIXx7MHzdPY498Pc3NHNi4cusiIjmsRwP63lGAAs+DY0XIS8jVorcSjDGgIp5VIpZcYgr01AtfkHvv+HvuY6l7oXeEtK+fGcXUpZKRWdwHOA/moGj0XmfUOtOP7wj1oruT4le1SnqQXfAdMnQdZv3TKJ6qZOXjs6wgb3GvHioUs0d/Tw9UX666c8ZklZDZFpsO+3LjUrsNY1tBlYb36/HrjeOu11XOMWGmBEBCq+kGelHgNb4B+p6rWffAUanPTHVErY+zgEjVepfwNYkBTGnIQQ/rK3hI5u53yYmzu6eWr/ORZNjmBqnLGAzGkwmWDRj6Cu2KVmBdYagieAZUKIYmCp+TNCiNlCiI/XbAshEoB4YN81578khMgFcoFw4NdW6jGwFTd+V6WU7vut1koG5+xONRtY9ANw/3TzFiEE3106meqmTl494pyxgucOXuBqWzffz56stRSDa0m9Tc0K9v/OZWYFVhkCKWWdlHKJlDLZ7EKqN2/PkVJ+ZcBxF6SUsVLKvmvOXyylzDS7mj4vpXS9vC1nJShOxQpOvgRXirVW82n6ZwMhCTBt3aCHzE8KIysxlP99/5zTzQoa27p5+kApy9KijNmAM2IywaIfwpWzLjMrMFYWGwzNjd8Ddx/1o+tMFL2jisvd9ENw8xj0kP5ZQU1zJy98dHHQY7Ti6QOlNHf08L1lxmzAaUldC5HpsPfX0NOptRq7YxgCg6Hxj1DrCvLfgspTWqtR9HTB7scgfPKwxeXmJ4WxaHIEf3qvmPrWLgcJvD7VTR08e/A8q6bGkBoTqLUcg6EwmSD7lyo1+cjTWqv5BDuVfzEMgcH1mf+IqsWy8yfOUYPo6DNQfw6yf63aDQ7DT1el0tbVyx/fPesAccPzux1n6OmV/PDWKVpLMRiOSUshabGKFbTVa60GqnLhmSV2Ses2DIHB9fEJhsU/hQsHtG/e0VYP+56AibdAcvaITkmOCmBdVjwvHr5ESY22IajTZQ28ebyML92YwIQwY92ALsj+NXQ2w/7/1FaHlLD9R1B/HnxDbX55wxAYDM+sL0FUJuz6qbb9CvY+rh7KWx+3qPHMvy6djK+HG7/eWoDUaFYjpeSXWwoI9/fkkVsmaaLBYBREpav05CNPQ02Rdjry3oSLB2HJzw1DYKARJjdY8VvV5P6DJ7XRcPmIKgg256vq4bSAcH8v/mVpMu+fqWVrbuXwJ9iBV49eJufiVX5w6xQCvAcPcBs4KUseUz0utvwL9PUNf7ytaW9Qg7CYaTDzi3a
5hWEIDEZGwg0qOPvBfytfpSPp6VIPYeA4WPKzUV3iwQUJZMYG8YvNBTS2ObYgXXVTB7/ZVsjcxFDumRXv0Hsb2AD/COUiunwIjj8//PG2ZvfPoKUaVv/hUyvobYlhCAxGzvInwCcU3v4m9Drwx/TDP6o+Cav+C7wCRnUJdzcTT9yVydW2Lh7fVmBjgUMjpeRnb+fR1dPHE3dNxWQyuo/pkumfg4SFKmOtqcJx9z23F47/U9VAip1pt9sYhsBg5PiGwur/Vjn8B/7bMfcsPwbvPwHpd8KUFVZdKn1cEA/fNJHXc8rYme+YTlRvnShnV0E1/7p0slFYTs8IAWv+CH3dsPFhx6w47miELd+BsElw84/teivDEBhYRuoayLxHZe+cP2Dfe3U2w5tfAf9oZYBswHeXTiYzNogfbjhNuZ07mZXWtvDTt/PISgzlqwsT7XovAwcQlgQrfqcy6A7auSCjlLD529BYDrf/FTx87Ho7wxAYWM7qJyE0CTY8BM3V9rmHlLD5O2pBz11Pq7UMNsDT3cSf182gp7ePf3nlBN299gn+tXf18sjLJ/B0N/HH+6fj7mY8amOCGZ+H9DtUBtulQ/a7z9FnVLr20sdUnwQ7Y3w7DSzHKwDu/Sd0tcAbD9pnCf6B30P+Rlj8M5iwwKaX/v/tnXtw1cUVxz9fEx4CSkAQIbxrUBGMAkVQ26IiQa2ggg/QggrDdKz12Vas9dWxTrU+UFtRfI5UBUS0DFqQIKitHUqoyitAAj5IKgZohCmiDWT7x27gysN4c2/uJfd3PjO/4bfnt7N3zz03nN/unt3TtU1z7rmwN0WfVHLbayuSHlJaXe246eUPKN64jQcvzqd9y/p9mzNSiOQXbXO6wLTLoLIeji9ZtxDmToS8Ahj48+S3vx/MERh1o11PGP5H+PS95M+ZLp8Jb93to5ROuyF57cYw/MRcrjn9aKYt2cDjb69Patt/eHMNbyzfyK/PPo4zjq0taZ/R4Dg0B0bP8OsFL16S3F3HFcUwY4w/QmXEk/6oixRgjsCoO71GwJDfwarX4PUbkxNjvWq2dyxdToXzHolr41i83HhWD87L78C9c1fzzN+Sk/B+UuFaJi9ax6j+nRlv6wKZS5uj4eKp/riTqefDjsrE29y0Bp4f7tcDRs/w2QJThDkCIzFOuQZ+cBMsfQ5mjfcx/3Xlw+l+3SG3L4yeDo3qN5n7IYeIBy7KZ+jxR/HbOat4bFFpnaeJqqsd981dzaTCEkb27cjd5/dC9ejEjIOA7j+CS17wb/HPD08srPSzZfDsOf5+zGzISe1+E3MERuKceTsMvstvg39+mI90iIddVVB4J7w6AToPgMtn1nm/QLw0zj6ER0efxLD8Dtw3dw3XTvuAL/+3M642tu6o4uoX/sVji9Yxqn8n7h1xAlm2XyAa9BjincHmUphyOmxYEn8by2fC00N8etgr3oAjj01+P2shIUcg6SJJKyVVS+r3LfWGSlojqVTSxBh5N0mLg3y6pMYHasM4yDntehjxtH+zefxUP0L4LusGZUXw1GB/dEXfK+DyWSkdEgM0yvKRPb8sOIY5y/7N0Env8vbaTbWODpxzzFu5kYKH3mF+8ef85tzjuOeC3uYEokaPITB+vv+P/JkCmH+7D32uja3lfj3glXHQ4USYsMhPOaUBJRIxIek4oBp4AviFc65oP3WygLXAWUAZsAQY5ZxbJWkGMMs5N03S48CHzrnJtX1uv379XFHRPh9lHAxsLvHHQXzyd78R5vvjoUcBtO6+p86OSvjoHXj/z1DyJjRvC+c+CD2Hpa/fgX+s28Ktry5n/ebt9Omcw+iTuzDomLa0adFkd52KbV9RWFzBtCWfsqxsK3lHtuD+i/LJ72TZxiLNjkp48zZ4fyo0zYF+V/pQ03a99hwNsfNrH3a6/GVYNh0QDLoZTrn2gEmWkomkpc65fV7aE3IEMY0v4sCOYCBwp3OuIJRrtsj9HtgEHOWc27l3vW/DHMFBjnM
+Bvq9R6E82KnJ4dDsCKj60p+bAnBYe+h3FQy42h/qdZDwVdUuZhRt4Ml317PhP37TWatmjWjRNJvtX+/aneSme9vm/PSH3+PCPrm2T8DYQ/lSeOcBWPtXcNWQ1QQOa+f/LraVe1n2oZB/qY+Ka9UlZV07kCOoPbNH4uQCG2LKZcDJwBHAF865nTHy3AM1ImkCMAGgc+fO9dNTIzlIcPz5/tpcAusXwZZS2L4ZGjfzuYY79vf7A+rpEK1EaNooizEDu/KTAV1YUb6N99Zt5uMtX/JV1S6aNsoi78gW9O/WmuM7HG4Lwsa+5PaFUS/633tpIXy+Ev5b4f8uWnaC9if4hDeND54jR2p1BJIKgaP28+hW51zKMpU456YAU8CPCFL1uUaCtMnzVwNEEr07tqR3x9SuWRgZQvM2/q2/AVCrI3DODU7wM8qB2FiojkG2BciRlB1GBTVywzAMI4WkYmJzCZAXIoQaA5cCs51fnFgIjAz1xgJpzoVoGIYRPRINH71AUhkwEHhd0rwg7yDpDYDwtn8NMA8oBmY451aGJm4GbpRUil8zeDqR/hiGYRjxk5SooVRjUUOGYRjxc6CoIYt5MwzDiDjmCAzDMCKOOQLDMIyIY47AMAwj4jTIxWJJm4C6pgZqA2xOYncaAqZzNDCdo0EiOndxzrXdW9ggHUEiSCra36p5JmM6RwPTORrUh842NWQYhhFxzBEYhmFEnCg6ginp7kAaMJ2jgekcDZKuc+TWCAzDMIxvEsURgWEYhhGDOQLDMIyIEylHIGmopDWSSiVNTHd/koGkTpIWSlolaaWk64K8taT5kkrCv62CXJIeCd/BMkl90qtB3ZGUJel9SXNCuZukxUG36eHYcyQ1CeXS8LxrOvtdVyTlSJopabWkYkkDM93Okm4Iv+sVkl6S1DTT7CzpGUkVklbEyOK2q6SxoX6JpLHx9CEyjkBSFvAn4GygJzBKUs/09iop7ARucs71BAYAPwt6TQQWOOfygAWhDF7/vHBNACanvstJ4zr80eY13As85Jw7GqgExgX5OKAyyB8K9RoiDwNznXPHAvl43TPWzpJygWuBfs65XkAWPp9Jptn5OWDoXrK47CqpNXAHPg1wf+COGufxnXDOReLC50yYF1O+Bbgl3f2qBz3/ApwFrAHaB1l7YE24fwIYFVN/d72GdOEz2i0AzgDmAMLvtsze2974XBgDw312qKd06xCnvi2Bj/budybbmT35zlsHu80BCjLRzkBXYEVd7QqMAp6IkX+jXm1XZEYE7PlR1VAWZBlDGAqfBCwG2jnnPguPNgLtwn2mfA+TgF8B1aF8BPCF84mQ4Jt67dY5PN8a6jckugGbgGfDdNhTkpqTwXZ2zpUD9wOfAp/h7baUzLZzDfHaNSF7R8kRZDSSWgCvANc757bFPnP+FSFj4oQl/RiocM4tTXdfUkg20AeY7Jw7CdjOnukCICPt3AoYjneCHYDm7DuFkvGkwq5RcgTlQKeYcscga/BIaoR3Ai8452YF8eeS2ofn7YGKIM+E7+FUYJikj4Fp+Omhh4EcSdmhTqxeu3UOz1sCW1LZ4SRQBpQ55xaH8ky8Y8hkOw8GPnLObXLOVQGz8LbPZDvXEK9dE7J3lBzBEiAvRBw0xi86zU5znxJGkvC5noudcw/GPJoN1EQOjMWvHdTIx4TogwHA1pghaIPAOXeLc66jc64r3o5vOecuAxYCI0O1vXWu+S5GhvoN6s3ZObcR2CDpmCA6E1hFBtsZPyU0QFKz8Duv0Tlj7RxDvHadBwyR1CqMpIYE2Xcj3YskKV6QOQdYC6wDbk13f5Kk02n4YeMy4INwnYOfG10AlACFQOtQX/joqXXAcnxERtr1SED/QcCccN8d+CdQCrwMNAnypqFcGp53T3e/66jriUBRsPVrQKtMtzNwF7AaWAFMBZpkmp2Bl/BrIFX4kd+4utgVuCroXgpcGU8f7IgJwzCMiBOlqSHDMAxjP5gjMAzDiDjmCAz
DMCKOOQLDMIyIY47AMAwj4pgjMAzDiDjmCAzDMCLO/wHIWvqAaYANKQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "from jax import vmap\n", "\n", "t = jnp.linspace(0, 10, 1000)\n", "\n", "plt.plot(jnp.sin(t))\n", "plt.plot(vmap(grad(jnp.sin))(t))" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "iS8nRuBZYLcD", "outputId": "299dc977-ff2f-43a4-c0d2-9fa6c7eaeeb2" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 25, "metadata": { "tags": [] }, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9d3yU15X//z6jXpCECuogAQIVRBWywd0GmWawsR0bp+AkjlN3s0l+yTrfZJNsNskrye5+k82mfOMUx3HibsdgU0wzxrFNEVUFAUJUdSSQEEL9/v54Ro6MJUtTn3lm7vv1mpdm7tM+o5lnzr3nnnuOKKXQaDQaTeBiM1uARqPRaMxFGwKNRqMJcLQh0Gg0mgBHGwKNRqMJcLQh0Gg0mgAn2GwBzpCYmKiysrLMlqHRaDSWYv/+/ReUUknXtlvSEGRlZVFaWmq2DI1Go7EUInJmuHbtGtJoNJoARxsCjUajCXC0IdBoNJoARxsCjUajCXC0IdBoNJoAxy2GQET+KCJNIlI+wnYRkV+ISLWIHBGRuUO2rRWRE/bHWnfo0Wg0Gs3YcdeI4E/Akg/ZvhTIsT8eBX4DICLxwHeB64Bi4LsiMt5NmjQajUYzBtyyjkAptUtEsj5kl1XAn5WR83q3iMSJSCpwK7BVKdUKICJbMQzKM+7Q5Q/09g9QUddORV0blzp7EYG02AjmThzPxIRIs+VpNB5DKcXJ5iuU1V6ioa2bvv4BkmPCyUuNoSAtBptNzJboN3hrQVk6cG7I6/P2tpHaP4CIPIoxmmDixImeUelDNLZ38btdNfztYC0tV3qG3SdnQjSfWDCJ+4syCQ8J8rJCjcYzXO3p5y+7z/DsvrOcbL4y7D6J0WE8MD+DtQuzmDAu3MsK/Q/LrCxWSj0OPA5QVFTkt9V0evsH+PUbJ/nNm9X09SsW5yezrDCV2ZlxTIgJo39Acba1k90nW3j5YC3/tq6C3+6q4Xt3FbAoP9ls+RqNS2w4Us+/v1pB0+Vu5meN55M3ZFOcHU/m+EhsNmho6+LA2YtsLGvg1ztP8qe3T/OVxdN4eGEWwUE69sVZvGUIaoHMIa8z7G21GO6hoe07vaTJ56i7dJXP/2U/h8+3sXxmKv96Z+6w7p/clBhyU2JYuzCLd0+28L1XK3jkz6WsXTCJ/7M8j7BgPTrQWIuu3n7+z8tlvHywlpkZsfzyobkUZ8d/YL9JCVFMSojinjkZ1DR38B+vVfKDDUfZdrSRX6yZo0cHTiLuKlVpnyN4TSk1Y5hty4EvAcswJoZ/oZQqtk8W7wcGo4gOAPMG5wxGoqioSPlbrqHy2jY+9ad9XO3p5yf3zWRZYeqYj+3pG+Cnm6v4/d9PcV12PL9fW8S48BAPqtVo3EdLRzef+XMpB8
5e4st35PBPt08dc+9eKcXLB2r51itljI8M5alPX8fUCdEeVmxdRGS/Uqro2nZ3hY8+A7wLTBeR8yLyaRH5nIh8zr7LRqAGqAZ+B3wBwP6D/x/APvvj+6MZAX/kwNmLPPDbdwkJsvHi5xc6ZAQAQoNtfHtFPj9/YDb7z1zkod/t4eII8woajS9xoaObj/z2XSrq2vnNR+fylcXTHHLxiAj3zsvgpc8vpLdf8cBv36W8ts2Div0Tt40IvIk/jQgq6tpY8/huxkeF8vxnF5Ac49rQdkdVI5/7ywEK0mJ4+pHriQjVbiKNb3Kps4c1v9vDqQsdPPnJYq6bnODS+U5duMLHfr+Hq739vPz5hWQlRrlJqf/g0RGBxjnq267y8BP7iA4L5q+PXOeyEQC4PTeZXzw4h8PnLvGlpw/Q1z/gBqUajXvp7R/gc3/Zz8mmDn73iSKXjQBAdmIUT326GKUUa5/Yy4WObjcoDQy0ITCJrt5+PvfUfjq7+/jTp4rJGO++NQFLZqTw7ysL2F7VxH9vPe6282o07uL7r1ayu6aVn9xXyE05H6iT4jSTk6L5w8PzaWzv4gt/1R2hsaINgUl8d10Fh8+38bMHZjMteZzbz//xBVmsKc7kNztPsrWy0e3n12ic5cX953lq9xk+e/Nk7pmT4fbzz504nh/dU8jeU6381xbdERoL2hCYwObyep4rPccXbp1CSUGKx67z3bsKmJEew9eeP0R921WPXUejGStnWzr57rpyrsuO5xtLcj12ndVzM3jouon8vzdP8saxJo9dx1/QhsDLNLV38c2XyyhMj+Uri6d59FrhIUH8cs1cevsVj71UhhUDAzT+Q1//AP/y3EFsNuH/PjCbIA+niPjOinymJ4/jsZeO0NbZ69FrWR1tCLyIUop/fekIV3v7+dkDswnxwkrIrMQoHluay5vHm3m+9NzoB2g0HuLxt2o4cPYSP7h7BulxER6/XnhIEP91/ywudPTw/dcqPX49K6MNgRfZVN7AG8ea+fqduV5d9PLx6ydx/eR4/uO1ozS0dXntuhrNIGdbOvmfbSdYUpDCqtnDphPzCIUZsXz+lim8dOA8O6r0XNlIaEPgJS539fLvr1aQnxrD2gWTvHptm0346b2z6O0f4Ecbj3r12hqNUorvrC8n2CZ8d2W+16//T3dMZeqEaL63vpKu3n6vX98KaEPgJX629QRNl7v54T0zTEmONTEhks/eMoX1h+vYXdPi9etrApfN5Q3sPNbMV0umkxrreZfQtYQFB/G9uwo429rJ73bVeP36VkAbAi9Q3dTBk++e5sH5E5kz0by6O5+/ZQrpcRF8d12Fjq/WeIXuvn5+uPEouSnjvD4SHsqNOYksKUjhVzurqbukI+iuRRsCL/DTzVWEB9v4Wolno4RGIyI0iH9bkcexxss8s09PHGs8z1PvnuH8xat8e3m+6Wmiv7U8D6Xgx5uqTNXhi2hD4GFKT7eypbKRz90yhcToMLPlcGdBCkWTxvO/209wtUf7SzWeo62zl//dUc3N05K4MSfRbDlkxkfyqRuzWX+4jsq6drPl+BTaEHgQpRQ/2niUCePC+PRN2WbLAYxsjd9YkkvT5W7+9M5ps+Vo/Jhf76ymvauXxzy4cMxRPnfzFGLCg/mvLcfMluJTaEPgQbZUNnLg7CW+ungakaG+UwyuODueW6cn8Zud1XqhjcYjNLR18cQ7p7lnTjr5aTFmy3mP2MgQPnvLFHZUNVF6OuAy3o+INgQeQinFL7afICshkvvmuT+fiqt8/c7ptHf18dtdJ82WovFD/t+bJ+kfUHxlkbnzYsPxyRuySIwO46ebj+nV9na0IfAQO6qaqKhr5wu3jb3akjcpSItl+cxUnnzntB4VaNxK0+Uuntl7ltVz0smMd19WXXcRGRrMP90+lb2nW3lXh1ID7qtQtkREjolItYg8Nsz2n4nIIfvjuIhcGrKtf8i29e7QYzZKKX6xo5qM8RHcM8d7qygd5Uu3Te
VKTz9PvnvabCkaP+J3u2roG1B86fapZksZkQfmZ5IYHcZvduoRMbjBEIhIEPArYCmQD6wRkfctH1RKfUUpNVspNRv4X+DlIZuvDm5TSq10VY8v8NaJCxw+d4kv3DrVK/mEnCUvNYY7cifwx7dPcaW7z2w5Gj/gQkc3f9l9llWz05iU4LsVwsJDgnjkpuz37tVAxx0zmMVAtVKqBkBEngVWASNleVoDfNcN13WcTf8KZ3d75twFd8ONXwHglzuqSY0N5955vjsaGOQLt03l3t+8wzN7z/LITZPNlqOp3Q+vfxt6O43XMjRDpzjZNmSbw22MvN/g63FpsPIXEBTCE2+foquvny/e5rujgUE+et1Efv1GNb/eWc1vP/6B6o0BhTsMQTowdHXSeeC64XYUkUlANrBjSHO4iJQCfcCPlVKvjHDso8CjABMnTnROaXgcRCc7d+yH0XYOdvwQClZzqCOWvadb+fbyPMKCfb9e8LxJ41kwOYHHd9Xw8QWTLKHZb1EKNnwNmo9D1o2Aev8248nIbe+b+Bxr24ecayzX7OuCmp2Qs4jOaav4y+6z3JmfwpQk7yVVdJZx4SE8vDCLX+yo5kTjZXI8UCDKKng7pvFB4EWl1NCVTJOUUrUiMhnYISJlSqkPOO6UUo8Dj4NRvN6pq9/2TacOG5W2WvifmbD7N/zh0gOMCwvmgfmZnrmWB/jCbVP4+B/2sv5QHfcXWUe333F2N9QdhOX/DfMfMVvN2BgYgF8Wwbu/4qX2ubRd7eUzN/vGmpmx8PAN2fzurVM8vquG/7x/ltlyTMMdDuxaYOivR4a9bTgeBJ4Z2qCUqrX/rQF2AnPcoMm7xKZDwWoGDj7FrrKTPFicybjwELNVjZkbpyYyLTmaJ94+rcPpzGT3ryBiPMxaY7aSsWOzwfWfh9r97N61mdmZccw1MZ+Wo8RHhbJ6bjrrDtfREsDF7t1hCPYBOSKSLSKhGD/2H4j+EZFcYDzw7pC28SISZn+eCNzAyHMLvs31n8PW08F9tjdZuzDLbDUOISJ88oZsKuvb2XtKL7IxhdZTULUB5n0SQn13knVYZj9Eb2gsy6+8zCM3ZSPDzS/4MA8vzKKnb4Bn9p41W4ppuGwIlFJ9wJeA14GjwPNKqQoR+b6IDI0CehB4Vr2/y5kHlIrIYeANjDkCSxqCjsRZHFTT+VzENjJizc8p5Ch3z04nLjKEJ94+bbaUwGTPb0GCoPhRs5U4TmgUr4XcyZ1BpSxJs16vOid5HDflJPLU7jP0BmhWXrfENiqlNiqlpimlpiilfmhv+45Sav2Qfb6nlHrsmuPeUUoVKqVm2f/+wR16zOCF0nP8vvdOknrr4PjrZstxmIjQIB4qnsiWygbOtXaaLSew6GqHg0/BjNUQk2q2Gocpr23jxy03I2IjuPRxs+U4xSdvyKKxvZuNZfVmSzEF3w1ytxBKKZ7afYaG9MUQkwF7fmO2JKf4+IJJiAh/fve02VICi7LnoacDij9rthKn+MvuM7SHJNGXuxIO/hV6rpgtyWFunTaB7MSogB0Ra0PgBnbXtFLTfIU110+G4kfg1C5osl5JyNTYCJbOSOHZfed0impvoRTs+yOkzoL0uWarcZj2rl7WHapj5aw0Qq9/FLrboPzl0Q/0MWw2Ye2CSRw6d4lDAbjATBsCN/DXPWeICQ9mxcxUmP0xsIXA/ifNluUUH7t+Epe7+tgQoENkr3NuLzRVQNGnh1/E5eO8crCWq739fPT6iTDxekjKg9I/mi3LKe6dl0FESBDP7Am8SWNtCFyk+XI3r1c0cN+8TMJDgiA6CfLugsPPQK/1SuJdlx3P5KSogI6g8Cqlf4SwGJhxr9lKHEYpxV92n2FmRiwzM+IMQ1b0Kag7YKyHsBjjwkNYOSuN9YfruNwVWIkYtSFwkRf2n6O3X/HQdUNWO897GLouQaX1cuiJCA8VT2T/mYsca7hsthz/prMVKv4GMx+AMN9fiXstpWcucryxg48O/e
7PegBCIi07Klhz3USu9vaz7lCd2VK8ijYELjAwoHh6z1munxzP1AlDbuSsmyB+Mux/wjxxLrB6bgahQTY9KvA0h/4K/d1GL9qC/HX3GcaFB3PXrLR/NIbHQuF9UPYidLWZJ85JZmXEkpcaw9N7zgbU4kptCFxg14lmzl+8ykevm/T+DTYbzF0LZ9+FJusVyo6PCmXJjBRePnCerl49aewRBgag9AmYuACS80ff38dovdLDxrIG7p2b8cHqe0WfMpLmHXneHHEuYIyIM6msb6es1nqGzFm0IXCBZ/eeIyEqlDsLUj64cfZHjUnjA9acNF5TPJH2rj42HNGTxh7h7DvQetJwI1qQvx2spad/gDXFwySATJsDKTONtREWZNWcdMJDAmtErA2Bk7Re6WF7VSP3zEknNHiYf2N0EuStgENPQ2+X9wW6yPWT45mcGMXTAXQzeJVDT0PoOMizZgmOF/efZ1ZGLNNTRsjYOedjUH8YGsq9K8wNxISHcNfMNNYdqqMjQOp0aEPgJOsP1dLbr7j3w+oRz/2EMWl8bKP3hLkJEeGB+ZnsP3ORmuYOs+X4F90dUPEKzLgHQn2vlONoVNS1cbS+/cNrcRfeD0GhxjyIBXmweCKdPf1sOBIYk8baEDjJiwfOMyM9hrzUmJF3yr7FKNpx+JmR9/Fh7pmTjk3g5QMjJZPVOMXR9dB7xXAfWpAX958nNMjGylkfUngpMh6mL4Ujz0Ffj/fEuYm5E+OYnBjFSwHy3deGwAmO1rdTXtvOfXM/pEcEYAsywumqt8PlRu+IcyMTYsK5KSeJvx2sZWAgcCIoPM6hpyF+CmQOW7/Jp+npG2DdoToWFyQTGzlKqvXZH4XOFjixxTvi3IiIcO+8DPaeag2I3FvaEDjBS/vPExIkrJw9hlKUs9aA6oeyFzwvzAOsnptO7aWr7D7VYrYU/6D1FJx+C2Y/ZMmVxDuqmmi90vPhbqFBptxhVAS0qHvo7jnpSICMiLUhcJDe/gFeOVTLHbnJxEeFjn5A0nRImwuHn/W8OA9wZ0EK48KCeWm//98MXuHws4DArAfNVuIUL+4/z4RxYdw0NXH0nYOCjcVyx1+HjibPi3Mz6XERLJicwMsHz/v9mgJtCBzkzWPNXOgYY49okNkPQWMZNJR5TpiHCA8JYvnMVDaV19PZExgRFB5jYAAOPw2Tb4VYB74/PkLz5W7eONbEPXPTCQ4a40/HnI8ZI+Ijz3lWnIdYPTeDMy2d7D9z0WwpHkUbAgd5cf95EqNDuWV60tgPmnGvsabAoqOC1XMz6OzpZ3N5g9lSrM2Zt+HSWctOEq87VEv/gBp9bmwogyNiCy4uA1gyI4WIkCC/nzR2iyEQkSUickxEqkXksWG2PywizSJyyP54ZMi2tSJywv5Y6w49nqKts5ftVY2smp1OyFh7RGCPoFhi3Az91utVz88aT2Z8BC8dOG+2FGtz5DkIjYbc5WYrcYpXDtUyMyOWnOQR1g6MxMyPQMMRaD7mGWEeJDosmKUzUnjtSJ1fr7J32RCISBDwK2ApkA+sEZHh1sw/p5SabX/83n5sPPBd4DqgGPiuiPhs5etN5fX09ivuHssk8bXMWgNXmuDkdvcL8zAiwuo5GbxzsoW6S9bLqOoT9HUbSQjz7rLk2oHqpg7Ka9tZ5cx3v2A1iM3CARMZXO7qY9tR60X+jRV3jAiKgWqlVI1Sqgd4Flg1xmPvBLYqpVqVUheBrcASN2jyCOsO1ZGdGMWM9A9ZOzASOSUQEW/hmyEdpeDVw4GxwMbtnNhqFG2ZcZ/ZSpxi/eE6RDBqbjjKuGRjTU3ZC0YhHouxYEoCqbHh/M2P3UPuMATpwLkhr8/b267lXhE5IiIvikimg8ciIo+KSKmIlDY3N7tBtmM0tnex+1QLK2elIc6E/QWFQP4qqNoIPdaLS56UEMWszDjWa0PgHOUvQmQiTL7FbCUOo5Ti1cN1LJicQHJMuHMnKbwfLp6G86Vu1eYNgmzCip
mp7DrRzKVO6y2OGwvemix+FchSSs3E6PU7nIlNKfW4UqpIKVWUlOTARK2bePVwHUrBytlpo+88EjPuNVaUHt/sPmFe5K6ZqVTUtXNSp5xwjO7LcGwTFNxjdAgsRlltG6cuXGHlLBe++3l3QVCYUZ/ZgqyclU5vv+L1Cv8MmHCHIagFMoe8zrC3vYdSqkUp1W1/+Xtg3liP9RVePVxHQVoMU5JcKCAyaSFEp0D5S+4T5kVWzExDRLuHHKZqI/R1GXn6Lcj6Q3WEBAlLZzjhFhokPMYImCh/2ZIBEzPSY8hKiPTbEbE7DME+IEdEskUkFHgQeF9pLhEZ+g1aCQxWdn8dKBGR8fZJ4hJ7m09x+sIVDp9vY5UrowEwUk7MWG0subdg0Y6U2HCKs+LtoyPr+XpNo+wFiJ0IGcVmK3GY/gHFq0fquGXahNFTSoxG4Ueg8wKc2ukWbd5ERLhrVhrvnmyh6bL1sgmPhsuGQCnVB3wJ4wf8KPC8UqpCRL4vIoM5dv9ZRCpE5DDwz8DD9mNbgf/AMCb7gO/b23yKwV7AipkuGgIw3EP9PXD0NdfPZQIrZ6dxsvkKR+t1GcsxceUCnNxhdABs1lu2s/dUK43t3a53ggByFhsVzI5YM2Bi5aw0BhRsKvM/95BbvplKqY1KqWlKqSlKqR/a276jlFpvf/5NpVSBUmqWUuo2pVTVkGP/qJSaan/4XG1HpRTrDtVSnBVPWlyE6ydMnwdxkyzrHlo6I5Vgm/jtENntVL5irKwtvN9sJU6x/nAtkaFBLMpLdv1kwWH2gInXoNd6Ycg5yePITRnnl99963VRvExlfTsnm6+4Nkk8FBFjVFCz0+gtWoz4qFBuzEnU7qGxUvYSJOVCcoHZShymp2+AjWUNlOQnExEa5J6TFtwDPR1Qvc095/Myd81KY/+Zi5y/aL3Ivw9DG4JRWH+4jmCbsKzQhYmya5lxr9FLrHzFfef0InfNTKP20lUOnL1kthTfpq3WKEk54z5LZhp960QzbVd73dcJAsi62VhPU2Hd7z7gdyVctSH4EJRSbCyrZ+HUxLFlGh0ryQVGL7H8Zfed04uUFCQTGmzT0UOjcdQeMzFjtbk6nGRjWQMx4cHcONWN4dpBwUYo6fHNlnQPTUyIZFZmHK/6WeUybQg+hIq6ds61XmV54TDF6V1h0D105m2j12gxxoWHcPv0CWwoq6dfF6wZmcp1kFwICVPMVuIwPX0DbK1sYHF+yvA1uV2h4G5ru4dmplJe2+5XJVy1IfgQNpbVE2QTFue72RCA4SsFY+LMgiyfmUrz5W6/T8/rNO31cHa3MTlqQd4+eYH2rj6WubsTBJZ3Dw1GD24s8x/3kDYEIzDoFlowOcG9bqFBEnNgQr7Ra7Qgt+VOICzY5lc3g1s5+iqgLGsINpXVEx0WzI05YyhA4ygWdw+lxIYzd2Icm/woLbs2BCNQ1XCZ0y2dLPVEj2iQ/FVw5h1L1jOODgvm5mlJbC5v0PWMh6NyHSTlQdI0s5U4TG//AFsqG1mUN4GwYDdFC12Lxd1DywqNdCtnW/wjekgbghHYVFaPTYxSjR4jbyWgLOseWlaYQkN7F4fO6+ih99HRZMz/FNxtthKn2F3TwqXOXvdGyl2Lxd1Dg78Lm8r9Y0SsDcEIbCxv4LrsBBKjwzx3kQl5kJBjWffQHXnJhAQJm7R76P0cXY+V3UIby+qJCg3i5mkeTO5ocfdQZnwkMzNi2egn7iFtCIbheONlqps6PDNRNhQR48fi9N8tubgsJjyEG6cmsrGsQS8uG0rlOkicZoQIW4y+/gFer2jk9rxkwkM85BYaxOLuoSUzUjh87hK1flCsSRuCYdhYVo942i00SP4qY3FZ1QbPX8sDLC1MpfbSVcpr282W4htcuWAY9vxVllxEtvdUK61Xelg2wwvffYu7hwazsfpDLW9tCIZhU1kD8yfFM8HZIhyOkFII47P+sfjIYpTkJxNsEzb6ia/UZapeAzUA+dacH9
hQVk9ESBC3Tp/g+Yu9zz1kvYye2YlR5KaMY7MffPe1IbiG6qYOjjVe9my00FAG3UM1O+Gq9WLy4yJDWTAlgU1l9do9BEbvNn6KJXML9Q8YhVduz53gvtxCo5G30nAPnXrTO9dzM8sKUyk9c5GmdusZsqFoQ3ANg9bdpSIcjpK/Cgb6jCpWFmTpjFROt3RS1RDgqak7W+HULsu6hfadbuVCR4/3OkEA2TdDWIxlR8TLClNQCstXLtOG4Bo2ljUwb9J4UmK94BYaJG0uxGZCpTVvhpKCZGyCjh6q2mDM91g0WmhTWT1hwTZu84ZbaJDgUMgpMTpBFqxcNnXCOKZOiGajxWsUuMUQiMgSETkmItUi8tgw278qIpX24vXbRWTSkG39InLI/jD1l/BMyxUq69tZ6o2JsqGIGEPkk9uhy3qTronRYRRnx/tNKJ3THF1v1JpInWW2EocZGFBsrmjg1ulJRIUFe/fieXdBZwuc2+3d67qJZTNS2HOqhZaO7tF39lFcNgQiEgT8ClgK5ANrRCT/mt0OAkX24vUvAj8dsu2qUmq2/bESE9lSYazw9Uq00LXkrzIqlx33uUqdY2JZYSrVTR2caAxQ91D3ZWOeJ+8uS7qFjtS20djebc53f+oio7C9Rav2LZmRyoCCLZXWyxAwiDtGBMVAtVKqRinVAzwLvG9srJR6Qyk1uBZ7N0aRep9jS2UD+akxZMZHev/iGfMhOtmyq4xL7In5rHwzuET1NsOQ5y43W4lTbKloIMgm3J7rRbfQIGHRMOV2e8SV9QIO8lLHMSkh0tLzBO4wBOnAuSGvz9vbRuLTwNBZ0XARKRWR3SIyYsydiDxq36+0ubnZNcXDcKGjm9IzFykpcENJPmew2WD6MuMHxYKhdCmx4czKjGOLhW8Gl6jaAJEJkHmd2UqcYktlI9dlxxMX6YEEi2MhbwW0nYP6Q+Zc3wVEhJL8ZN6pbuFyV6/ZcpzCq5PFIvIxoAj4zyHNk5RSRcBDwM9FZNjk7Uqpx5VSRUqpoqQk9y993360EaX+0bM1hdwV9lC6XeZpcIGS/GQOn2+jvs36Ky0doq8Hjm+B6UvB5qWwSzdS09xBdVMHJfkmdYIApi0FCbKse6ikIIWe/gHePO7+Tqo3cIchqAUyh7zOsLe9DxFZBHwLWKmUem9WRSlVa/9bA+wE5rhBk8NsqWgkPS6CvNRxZlzeYDCUrupV8zS4wKB/eVuguYfO/B262wxDbkG22j+vxWbMDwwSlQCTFlrWNTp34ngSo0Pfm2e0Gu4wBPuAHBHJFpFQ4EHgfdE/IjIH+C2GEWga0j5eRMLszxOBG4BKN2hyiCvdfbxVfYGSgmTEzIm+4FDIWWyE0g30m6fDSaZOiGZyUhSvW/RmcJqqDRASCZNvNVuJU2ypbGRGegzpcRHmCsm7C5qr4MIJc3U4QZBNWJSXzBtVTfT0DZgtx2FcNgRKqT7gS8DrwFHgeaVUhYh8X0QGo4D+E4gGXrgmTDQPKBWRw8AbwI+VUl43BLuON9PTN2CuW2iQ3OVwpRnO7zNbiVOU5Kewu6aFtk5r+kodRimo2ghT74AQk39InaDpchcHzl5kcZ6PfPfBXtTHepQUJHO5u493a1rMluIwbpkjUEptVEpNU0pNUUr90N72HaXUevvzRUqp5GvDRJVS7yilCpVSs+x//+AOPY6ytbKRuMgQ5meNN6TrJEwAACAASURBVOPy72fqYrCFWHaIXFKQTN+A4o1jTaPv7A/UHYTLdZZ1C20/2mTMjZkVJDGU2AxjcaVFv/sLpyQSGRpkyYCJgF9Z3Ns/wPaqJu7ITSY4yAf+HeExMPkWY9LMgqF0szPiSBoXxpZK690MTlG1wZjkzCkxW4lTbKloIDM+gtwUE+fGhpK3Amr3Q9sHphl9nvCQIG6dnsTWykbLVe3zgV8+c9l3qpW2q72+0SMaJHc5XDwFTUfNVuIwNpuwOD+Zncea6eq13jyHw1RtMCY5I+PNVuIwHd
19vF3dQkl+irlzY0PJvcv4a9G07CX5KTRd7uawxar2Bbwh2FLZSHiIjZtzPFiNyVGmLwPEwjdDMp09/bxz0nrFdhyi5SQ0H7WsW2jX8WZ6+gfMDRu9lqRpkDjdspFzt02fQLBNLLewMqANgVKKLRUN3JST5L20u2NhXIqx0tiivtIFUxKIDgu2bCjdmBk01LnLzNXhJFsqGoiPCmXeJB+YGxtK3go4/baRzdVixEaGcP3kBMvNEwS0Iaioa6eurcu3ekSD5C43Vlm2nTdbicOEBQdxW+4EtlY20m8xX6lDVG2AlJkQN9FsJQ7zj7mxCb4xNzaU3BVGFleL5t0qKUjmZPMVqps6zJYyZnzsG+BdtlQ0YBOjCLvPMehuqNporg4nKclPpuVKDwfOWq/YzpjoaIJzeyzrFtpT08rlrj4W+2InKG0OjEuz7Ih4kf33ZKuF3EOBbQgqGynKiic+yqT8Kh9G4lRL+0pvnZ5ESJBYbog8Zo5tApR13UKVDYSH2LjJl+bGBhExRsTV26Gnc/T9fYy0uAhmZsRaKnIuYA3BWXtFLZ90Cw2Su9yyvtJx4SEsnJLIlspG/yxhWbXBcAklzzBbicMopdha2cjNvjY3NpS8FdB3FWreMFuJU5TkJ3Pw7CUaLVLCMmANwaC19onVxCMx6Cs9scVsJU5RUpDMmZZOjjdax1c6Jro7jNoDuSssWXugvLad+rYuSszMLTQak26A8FjrRs7Z/7dWcQ8FsCFoJDdlHBMTTKg9MFbS5sC4VMv6ShfbfaV+5x46uR36u61be6DSPjdmRu2BsRIUAtOWWLaEZc6EaLISIi0TRhqQhqClo5vS062+3SOCITUKtkOv9VI7T4gJZ87EOMvcDGOmagNExEPm9WYrcYotFY0UZ8cz3hfnxoaSuxyutsLZd81W4jAiQklBCu+evEC7BWoUBKQh2F7VxIDCt+cHBslbAb2dcNKqvtIUymrbqL1kPUM2LP29cHyzUXsgyMu1fd3A6QtXONZ42bddooNMXQTB4dZ1D+Un09uv2HnM92sUBKQhGKw9UJAWY7aU0Zl0o1Gj4JhFbwZ76g6/qVFw5m3oarOsW+i92gNW6ASFRsHk2wxDYMGAgzn2GgVWKGEZcIbgak8/f69uZnG+ybUHxkpwqJHQzKI1CqYkRTMlKcpSoXQfStUGCI4wfqAsyJbKBvLMqsvtDLnLoe0sNJSZrcRhBmsUvHmsme4+3753A84Q7DrRTFevj+VXGY3c5dDZYixgsiAlBSnsrmm1fo0CpQxDMOV2CLXID+kQLnR0s//MRWt996cvBbFZNmCipCCZju4+3j3p2zUKAs4QbKloJDYihPnZFsoWOXURBIVa2lfaP6DYcczi7qH6Q9Bea1m30I6j9rkxX8q0OxpRiTBxgWW/+4M1Cnw9jNQthkBElojIMRGpFpHHhtkeJiLP2bfvEZGsIdu+aW8/JiJ3ukPPSPT1D7C9qpE7cicQ4mv5VT6M8BjIvsWyvtJZGXFMGBfm8zfDqFRtMHqn05aYrcQptlQ2kB4XQX6qBebGhpK7HBrLofWU2UocJjwkiFum+X6NApd/DUUkCPgVsBTIB9aISP41u30auKiUmgr8DPiJ/dh8jBrHBcAS4Nf283mEfacvcqnTx2oPjJXcZZauUbDIH2oUVG2EiQuNQusW40p3H7tO+EBdbmeYbk/jccyiebcKkn2+RoE7usXFQLVSqkYp1QM8C6y6Zp9VwJP25y8Cd4jxbVwFPKuU6lZKnQKq7efzCFsqGwgLtnHzNB/MrzIagzeDRYfIlq9R0FoDTRWWdQu9dcKH6nI7Sny2kcrjqDXnCW6fnkyQj9cocIchSAfODXl93t427D72YvdtQMIYjwVARB4VkVIRKW1udi4ut6dvgEV5yUSGWi/+W9coMJnBLLBWTTJX4UN1uZ0hdwWc2w0dvh+Tfy2xkSFclx3v065RyzjKlVKPK6WKlFJFSUnO9eh/eE8hv3xojpuVeR
GL1yi4dXoS245atEZB1QajVzo+y2wlDtNnrz1wuy/WHhgructBDRiL+SxISX4y1U0dnGz2zbxb7vhW1AKZQ15n2NuG3UdEgoFYoGWMx7oVy/lHhzLd7pY4tslcHU5SUpDChY4eDp2zWI2CKxeM3qhF3UJ7T9vrclvRLTRISiHETrSsa3Sxjyehc4ch2AfkiEi2iIRiTP6uv2af9cBa+/P7gB3KyE28HnjQHlWUDeQAe92gyT9JmgYJOZZ1D/2jRoFv3gwjcnyz0Rudbl23kDE3lmi2FOcZrFFwcoeR/dVipMdFMCM9xmcTMLpsCOw+/y8BrwNHgeeVUhUi8n0RWWnf7Q9AgohUA18FHrMfWwE8D1QCm4EvKqUsHFbiBXKXw+m/w1WL9aqBmHCjnuvrFQ3WqlFQtQFiMiB1ltlKHGaw9sBNOYnWnBsbSt4KI+vrye1mK3GKkvwUDp67RNNl36tR4BaHoVJqo1JqmlJqilLqh/a27yil1tufdyml7ldKTVVKFSulaoYc+0P7cdOVUtb0eXiT3BUw0AcntpqtxClKClI43dJpnXquPfaEf7nLLFl7oLK+ndpLV63tFhok83oj66tV3UP5ySgF2482mS3lA1h05iiASZ8H0cnWvRkGaxT4qK/0A5zcYVTKsuj8wJaKRkTg9jwfrj0wVoKCjZQTxzcbWWAtRm7KODLjI3zSPaQNgdWw2YyboXob9PreEHM0UmLDmZVpoRoFxzYalbIm3WC2EqfYWtlI0aTxJEaHmS3FPeQuN7K/nv672UocRkQoyU/h7eoWOrp9q9iONgRWJHcF9HTAqV1mK3GKkvxkDp+7REObjxuy/j4jQivnTqNilsU419pJZX27f7iFBplyO4REWnZEXJKfTE//AG/6WI0CbQisSPbNEBpt2eihO+0pPrYe9fFRwbk9RoUsiy4is1TtgbESEmEYA4vm3Zo3aTzjI0PY6mNp2bUhsCLBYZCz2F6jYMBsNQ4zJSma7MQon/SVvo+qDUbW16mLzFbiFFsrG5mWHE1WYpTZUtxL7gq4XAd1B8xW4jDBQTbuyEtme1UTvf2+c+9qQ2BVclfAlSaoLTVbicMYvtJkdte0+G49V6WMqnDZt0DYOLPVOMzFKz3sPd3qX26hQabdCRJkaffQ5a4+9tS0mi3lPbQhsCpTF4Et2LLuoZICH6/n2lQJF09bNlpoR1UT/QPKv9xCg0TGQ9YNljUEN+UkER5i86mqfdoQWJWIOMi6ycjIaEFf6exMI5LFZ91DVRsBsexq4q2VjaTEhFOYHmu2FM+QuwKaq+BCtdlKHCYiNIibc4waBb6ysFIbAiuTuxxaT8KF42YrcZggm7A4fwI7fbWea9VrkFEE46zXo+7q7efN40ZdbpvNeovgxsR7admtOSJenJ9MfVsX5bXtZksBtCGwNn5wM/hkPde280aWV4u6hf5+4gJXe/v90y00SFymkfLDou6hO/KSsQk+4x7ShsDKxKZD2px/5Mq3GD5bz3Uwu+t0axqCrZWNjAsL5vrJ1quk5hC5d8H5fXDZN35MHSE+KpT5WfE+k4BRGwKrk7vciBxqrzdbicOEhxg1CnyunmvVBiPLa9I0s5U4TP+AYtvRRm7NnUBosJ/f3rnLAWXptOzHGi9zpuWK2VK0IbA8uSuMv1at55qf4lv1XK9egtNvWXYR2YGzF2m50kOJP7uFBpmQB+OzLesaHfyMfGFErA2B1UnKhfjJlvWV3jZ9AsG+VM+1epuR3XXQwFqMrZWNhAQJt063YF1uRxmsUVDzJnT5xqSrI2TGR5KbMs4n3EPaEFidwZvh1C4jGZfFiI00ahT4TBhp1WsQNQHSi8xW4jBKKV6vaGDBlETGhVsvN5JT5N0FA71Qbd207KVnWrnQ0W2qDm0I/IHpy+03wzazlTjF4vxkTjZfMb+ea183nNhmZHe1We/WONHUwZmWzsBwCw2SMR+ikiw7Ii7JT2ZAwQ6TaxS49G0XkXgR2SoiJ+
x/xw+zz2wReVdEKkTkiIg8MGTbn0TklIgcsj9mu6InYMkshshEy94Mi33FV3pqF/RctmzYqF8mmRsNW5C9RsEWw5BbjIK0GNLjIkx3jbra7XkM2K6UygG2219fSyfwCaVUAbAE+LmIxA3Z/nWl1Gz745CLegITi98MaXERFKbHmu8eqlwHoeNg8q3m6nCSzeUNzMqMIzkm3Gwp3iV3hWHAT71lthKHEREW5yfz1olmOnvMq1HgqiFYBTxpf/4kcPe1OyiljiulTtif1wFNQADMZHmZwZvhtPVuBjCGyAfPXaKp3aQaBf19RuTVtDuN7K4W4/zFTspq21g6ww+TzI1G9i2WTstekp9Md98Au45fME2Dq4YgWSk1GMDeAHzomFREioFQ4OSQ5h/aXUY/E5ER70AReVRESkWktLnZRxOVmcnkWyAkyrLuoZKCFJSCbWb5Ss++A50tkL/SnOu7yOZyYzQVkIYgJNxIwnhsoyXTss/Pjic2IsTUVcajGgIR2SYi5cM8Vg3dTxnZk0ZcFSQiqcBTwCeVUoOf1jeBXGA+EA/860jHK6UeV0oVKaWKkpL0gOIDhETA1NstW6NgWnI0kxIizSvYUbkegiMsW3tgc3kDeakxTErws9oDYyV3BXQ0WjIte0iQjTtyJ7Cjqok+k2oUjGoIlFKLlFIzhnmsAxrtP/CDP/TDdudEJAbYAHxLKbV7yLnrlUE38ARQ7I43FbDkroDL9VB30GwlDiMiLM5LNqee68CA4VaYegeEWu+HtKm9i/1nLwbmaGCQnMWWT8t+qbOXfacvmnJ9V11D64G19udrgXXX7iAiocDfgD8rpV68ZtugERGM+YVyF/UENjkl9oIdVr0ZUsyp51pbahjQ/FWj7+uDvF7RgFIB6hYaJCLOKOFq0bTsN+UkERps43WTAiZcNQQ/BhaLyAlgkf01IlIkIr+37/MR4Gbg4WHCRP8qImVAGZAI/MBFPYHNYMGOo69a8maYN2k8CVGhbPb2zVC5DmwhxkSxBdlU3sDkpCimTog2W4q5DKZlbz5mthKHiQoL5uacJDaXN5iSd8slQ6CUalFK3aGUyrG7kFrt7aVKqUfsz/+ilAoZEiL6XpioUup2pVSh3dX0MaWUySuK/ID8VdBywqiwZTGCbEJJQQo7jjbS1eulGgVKGYZz8q0Qbr0iLq1XethzqpWlM1IwBtYBjMXTsi+fmUJDexcHz3k/75b1lk9qPpy8lYAYvVwLsrwwlSs9RmEVr9BwBC6dsWy00NbKBvoHFEtnpJotxXxi0iB9nmUj5+7ISyY0yMbGMu9nEtaGwN+IngCTboCKV8xW4hTXT45nfGSI926GyvXGvIpFaw9sKm8gY3wEBWkxZkvxDXJXQN0BaKs1W4nDxISHcFNOIpvK6r3uHtKGwB8puBsuHIOmo2YrcZjgIBt3FqSw/WiTd9xDR1815lWirFfEpe1qL29XX9BuoaFYPC37ssJU6tq6vJ6WXRsCf2TQPWTRUcGywlQ6uvvY5Wn3UPMxw2DmWdMt9EZVE739iiXaLfQPkqYZRYUsOk+wKD+ZkCDxuntIGwJ/ZFwyTFoIldY0BAumJBDnDfdQ5Xrjr0VrD2wqryc5Jow5mXGj7xxI5K2A03+Hq+bE5LtCbEQIN05NZGNZA8qLkX/aEPgr+XdDcxU0VZmtxGFCgmzcmZ/CNk+7h46ug4xiiLFej7qzp483jzezpCAFm027hd5H7l1GcSGLlrBcVphK7aWrHDnvvfoi2hD4K3l3YUQPWXNUsGym4R5664SHEnFdqIaGMii4xzPn9zA7qpro6h3QbqHhSJ8LcROh/GWzlTjF4vxkgm3edQ9pQ+CvxKTCxOstG0a6cEoCsREedA9VvAyIMbFuQV47XE/SuDCKs+PNluJ7iBgGvuYN6Gw1W43DxEWGcsPURDaW13vNPaQNgT+Tf7exsKz5uNlKHCYkyEZJfjLbKhvp7vOAe6j8JZi4wIg9txiXu3p541gTyw
tTCdJuoeEpWG24h46uN1uJUywvTOVc61XKa71Ti1kbAn9mcJGUhd1Dl7v7+Lu73UONlcb8yYzV7j2vl9h2tJHuvgHumqXdQiOSOgviJ1vaPRRkEzZ4yT2kDYE/E5MGmddbNoz0himJxIQHu/9mqHgZxGbZJHOvHq4nPS6COZkfqAyrGUQEZtxrFGrqMLcesDOMjwpl4ZQENnnJPaQNgb9TcDc0VcCFE2YrcZjQYBslBSlsdad7SCmjl5h1k7EK22Jc6uzhrRPNLJ+ZqqOFRqNgNagBy86TLStM5UxLJxV1nncPaUPg7+SvAsTwiVuQ5TNTudzV574yfg1HjAyVFnULvV7RQG+/YsVM7RYaleR8SMqFir+ZrcQplhSkEBIkrDvk+XQZ2hD4OzFpkHUjlL1gydTUN05NJD4q1H03Q/nLRgETi64mfu1IPZMSIilMt16mVFMoWA1n3oF27ydyc5XxUaHcMi2J9Yfr6Pdw7iFtCAKBwvugpRrqD5mtxGFCgmwsL0xl29FG1yuXKWXMD0y+1ajdYDEudHTzdvUFVsxM1bmFxsqM1YCybMDEytnpNLZ3s/eUZ8NgXTIEIhIvIltF5IT977CzVyLSP6Qozfoh7dkiskdEqkXkOXs1M427yV9lFF458oLZSpxi1ew0unoH2OJqwZraA3DprNFLtCCbyhsYUHDXLOuFvJpGYg4kF1rWNboobwKRoUGsP+xZ95CrI4LHgO1KqRxgu/31cFwdUpRm6Jj8J8DPlFJTgYvAp13UoxmOiPFGGcvyl2DASwVf3Mi8SePJGB/BukN1rp2o7AUICjUqWVmQVw/XMXVCNNOTx5ktxVrMuAfO74OLZ8xW4jCRocGU5CezsazBM+tp7LhqCFYBT9qfP4lRd3hM2OsU3w4M1jF26HiNg8y8HzoajHA6iyEirJyVxt+rL3Cho9u5k/T3QfmLMG2JUd/WYtReusreU62snJWm3UKOMuM+42+ZVUfE6bRd7XVfwMQwuGoIkpVSg7MwDUDyCPuFi0ipiOwWkcEf+wTgklJq0PF7Hkgf6UIi8qj9HKXNzV4ubu4PTFsCoeMsezPcPSed/gHFhiNOTvrVvAFXmmHmA+4V5iVeOWi4Bu6ZM+ItohmJ8ZNg4kI48pw1AyZyjICJVzwYPTSqIRCRbSJSPszjfatxlLHqYaT/8iSlVBHwEPBzEZniqFCl1ONKqSKlVFFSUpKjh2tCIoxEdJWvQm+X2WocZlryOHJTxjl/Mxx+9h8uMouhlOJvB2uZnzWezPhIs+VYk1kPwIXjRvUyi/FewESlGwImRmBUQ2AvSj9jmMc6oFFEUgHsf4ddwqeUqrX/rQF2AnOAFiBORILtu2UA1qsvZyUK74PuNjixxWwlTnH3nHQOnr3E2ZZOxw7svmzUsS1YDcHWi0cor22nuqmDe+ZkmC3FuuTfDUFhcPg5s5U4xarZaXT3uSFgYgRcdQ2tB9ban68FPrCET0TGi0iY/XkicANQaR9BvAHc92HHa9xI9i0QNcGy7qHBaBmHIygq10PfVcu6hV46cJ5Qe69Q4yQRcTB9iREw0d9rthqHmTtxPOlxEbziasDECLhqCH4MLBaRE8Ai+2tEpEhEfm/fJw8oFZHDGD/8P1ZKVdq3/SvwVRGpxpgz+IOLejQfRlCwEVd9/HVLVm9Kj4ugOCuelw/WOpZ/5chzMD4bMos9J85D9PYP8OrhOu7Im0BsZIjZcqzNzAeh8wJUbzdbicPYbMKq2Wn8/UQzTZfd79p1yRAopVqUUncopXLsLqRWe3upUuoR+/N3lFKFSqlZ9r9/GHJ8jVKqWCk1VSl1v1LKyZAQzZiZ/RD0d0PZi6Pv64PcOy+dmuYrHDg7xuLebbVwapcxGrBgtM1bJ5ppudLD6rnaLeQyUxdBRDwcedZsJU6xem4G983LoLff/RPeemVxoJE6y1hgc+ivZitxiuUz04gICeLF/efGdkDZC4CCmR/xqC5P8fKBWsZHhnDLNB
0g4TLBoUZG0qqN0OW9MpDuYuqEaH563yzS4yLcfm5tCAKROR+FuoPQWGG2EoeJDgtmWWEqrx6up7NnlAgKpYxooYz5kOBwoJrptHf1srWykbtmpREarG9VtzDrQWNEbNHU7J5Cf7sCkcKPGCknDlpzVPCRogw6uvvYXD5KBMX5Umg+CnM+5h1hbua1w/V09w3otQPuJH0eJORYdkTsKbQhCESiEmD6UmMS1YIRFMXZ8UxKiOT50lHcQwf/DCGRls0t9Ny+s0xPHsfsTOuthPZZRGDuJ+DcHmiqMluNz6ANQaAy52NGBMXx181W4jAiwv3zMthd0zrymoLuy1D2kmEEwmO8K9ANVNa1c/h8Gw8WZ+qUEu5m1hojFfnBp8xW4jNoQxCoTLkDolPg4F/MVuIUq+dmIMLIk8YVf4PeK0bvz4I8t+8socE27RbyBNFJMH0ZHH4G+nSgImhDELgEBRvL7k9sgcuNZqtxmLS4CG7KSeLF/eeHL9px4ClInG7JtQNdvf387WAtS2ekEBdpvZXQlmDuWuhsMVaca7QhCGjmfBxUPxyy5qjgI0UZ1LV1sevENUkIm47C+b0w9+OWXDuwqbye9q4+HpifabYU/2XKbRCbCQf+bLYSn0AbgkAmMcco4l76J0vWKSjJTyExOpS/7r4mz/yBp4yoqJkPmiPMRZ7de46shEgWTE4wW4r/Ygsy5slq3oCLp81WYzraEAQ68x+BtrNwYqvZShwmNNjGg/Mnsr2qiXOt9knj3qtw+GkjKiraeouwapo72HOqlY/M15PEHmf2RwExOg4BjjYEgU7ucmPSeN/vR9/XB1lz3UQEeGbvWaOh/CUjj1LxZ0zV5SxP7T5DsE24b55OKeFx4jKNtOQHngz4SWNtCAKdoBCY9zBUb4PWGrPVOEx6XAR35CXz3L5zdPf2wZ7fQlKe4fKyGFe6+3ix9DzLClOZMC7cbDmBwXWPGgWLKgM78bE2BBrDEIgNSp8wW4lTfPz6SbRc6WHPW69DwxFjNGBBt8rLB85zubuPtQuzzJYSOEy+HRKmGh2IAEYbAg3EpELeCmOBTe9Vs9U4zI1TE8lKiET2Pg5hMZasO6CU4sl3z1CYHsvciXolsdew2WD+Z6C2FGr3m63GNLQh0BjMf8TwrZe/bLYSh7HZhM/MieK6q2/RknMfhEWbLclh3q5uobqpg7ULs/QksbeZ/RCERsOex81WYhraEGgMsm6CCQWw+9eWLPB9L9sIlX5+23m72VKc4k/vnCY+KpQVM3UVMq8THmOknah4GTqaR9/fD3HJEIhIvIhsFZET9r/jh9nnNhE5NOTRJSJ327f9SURODdk22xU9GhcQgQVfhMZyqNlpthrH6L1K+IE/cCJmIX+sCqLukrXcW6cvXGF7VSNrijMJDwkyW05gUvwo9PdAaWAWSXR1RPAYsF0plQNst79+H0qpN5RSs5VSs4HbgU5gaPX0rw9uV0odclGPxhUK74PoZHj3l2YrcYzDz0DnBeIWfw0FPPH2KbMVOcTjb9UQEmRj7YIss6UELknTIOdO2Pu4JefJXCXYxeNXAbfanz8J7MSoQzwS9wGblFIjpIzUmEpwmBFxs+MHsO6LEPGBAZ5vUvEKpM0hacYdLC88xDN7z/FPd+QQE+77NX6bLnfx4v7z3Ds3gwkxOmTUVG74Z/jTcnjhYWPVva+y8MtuXyzpqiFIVkrV2583AMmj7P8g8H+vafuhiHwH+4hipLrFIvIo8CjAxIkTnVes+XCKHzUScVlp0tgWDMv/G0R49ObJrD9cxzN7zvLZW3y/KtkTb5+mt3+AR2+ebLYUzaQbjIizo68ada59lblr3W4IRI0yMSgi24CUYTZ9C3hSKRU3ZN+LSqlhu5EikgocAdKUUr1D2hqAUOBx4KRS6vujiS4qKlKlpaWj7aYJUB763W5ONnew6xu3ERbsuz73y129LPzxDm7KSeTXH51nthxNACAi+5VSRde2jzpHoJRapJSaMc
xjHdBo/zEf/FFv+pBTfQT426ARsJ+7Xhl0A08A1ssZrPE5vnjbVBrbu3lu3xgL3JvEX3af5XJXH5+zwMhF49+4Olm8Hlhrf74W+LB12muAZ4Y2DDEiAtwNlLuoR6Nh4ZQE5meN51dvVNPV65tZVS939fL4rpPcMi2JmRl6AZnGXFw1BD8GFovICWCR/TUiUiQi72UxE5EsIBN485rj/yoiZUAZkAj8wEU9Gg0iwlcWTaOxvZtnB5PR+RhPvH2ai529fK1kmtlSNBrXJouVUi3AHcO0lwKPDHl9GvhAzT2llDVX/2h8ngVTEijOjufXO0/yYPFEn4rPb+vs5Xdv1bA4P1mPBjQ+gV5ZrPFLBkcFTZe7eerdM6Mf4EV+91YNl7v6+OpiPRrQ+AbaEGj8lgVTErhlWhK/2HGC1is9ZssBoLG9iz++fYrlM1PJS40xW45GA2hDoPFzvr08j86efv5n23GzpQDw083H6OtXfOPO6WZL0WjeQxsCjV+TkzyONcWZ/GXPWaqbOkzVcuT8JV46cJ5P3pjFpIQoU7VoNEPRhkDj9/zLomlEhgTxgw2VjLaA0lMopfj+q5UkRofypdummqJBoxkJbQg0fk9idBhfXpTDzmPNbCirH/0AD/DsvnOUJ1agrwAACLVJREFUnrnI1++czjgL5EDSBBbaEGgCgocXZlGYHsv31lfS1tk7+gFupLG9ix9tPMp12fHcPy/Tq9fWaMaCNgSagCA4yMaP7y3kYmcPP9xY6bXrKqX4t1fK6ekb4Mf3zsRm09XHNL6HNgSagKEgLZZHb57M86Xneb2iwSvX/NvBWrZUNvIvi6aRnagniDW+iTYEmoDiK4umUZgeyzdePEKthyuZ1TR38O1XyinOjuczN2V79FoajStoQ6AJKEKDbfzvmjn09Q/w5WcO0ts/4JHrXO3p50tPHyQ02Mb/PDib4CB9q2l8F/3t1AQcWYlR/Gh1IaVnLvJvr5S7PaR0YEDxtRcOcbShnf/7kVmkxka49fwajbtxtUKZRmNJVs1O50RjB798o5pJCVF8/lb31QT4zy3H2FjWwLeW5XF77mhF+zQa89GGQBOwfHXxNM60dvKTzVWEBdv41I2u+/F/vu04v9l5kjXFE3lEzwtoLII2BJqAxWYT/vv+WfT2DfD91yrp6uvn87dMwaiT5BgDA4r/2nKMX+88yX3zMvjB3TOcOo9GYwZ6jkAT0IQG2/jfh+awclYaP918jH9+9hCdPX0OnaPtai9f+OsBfr3zJGuKM/nJvTMJ0usFNBbCJUMgIveLSIWIDIjIBwoiD9lviYgcE5FqEXlsSHu2iOyxtz8nIqGu6NFonCEkyIjs+fqd03ntSB1Lfv4Wbx5vHnUSWSnF6xUN3PmzXWw92si3l+fxo3sKtRHQWA5XRwTlwGpg10g7iEgQ8CtgKZAPrBGRfPvmnwA/U0pNBS4Cn3ZRj0bjFCLCF2+bytOPXE+wTVj7x73c+5t3eHH/eS50dL9v36b2Lp7ec5ZVv3qbzz61n3Hhwbz8+YU8ctNk7Q7SWBJXS1UeBUb78hcD1UqpGvu+zwKrROQocDvwkH2/J4HvAb9xRZNG4woLpiSw8cs38XzpOX73Vg3/3wuHARgfGUJ0eDBXuvvfK3IzOSmKn947k3vmphOi1wloLIw3JovTgXNDXp8HrgMSgEtKqb4h7R+oazyIiDwKPAowceJEzyjVaIDwkCA+sSCLj18/ifLadt45eYHTLZ109fYTHhJEzoRoirPjKUiL0SMAjV8wqiEQkW1AyjCbvqWUWud+ScOjlHoceBygqKjInKTymoBCRCjMiKUwI9ZsKRqNRxnVECilFrl4jVpgaO7dDHtbCxAnIsH2UcFgu0aj0Wi8iDccm/uAHHuEUCjwILBeGSEZbwD32fdbC3hthKHRaDQaA1fDR+8RkfPAAmCDiLxub08TkY0A9t7+l4DXgaPA80qpCvsp/hX4qohUY8wZ/MEVPRqNRqNxHDGrhq
srFBUVqdLSUrNlaDQajaUQkf1KqQ+s+dIxbxqNRhPgaEOg0Wg0AY42BBqNRhPgaEOg0Wg0AY4lJ4tFpBk44+ThicAFN8qxAvo9Bwb6PQcGrrznSUqppGsbLWkIXEFESoebNfdn9HsODPR7Dgw88Z61a0ij0WgCHG0INBqNJsAJREPwuNkCTEC/58BAv+fAwO3vOeDmCDQajUbzfgJxRKDRaDSaIWhDoNFoNAFOQBkCEVkiIsdEpFpEHjNbjzsQkUwReUNEKkWkQkS+bG+PF5GtInLC/ne8vV1E5Bf2/8EREZlr7jtwHhEJEpGDIvKa/XW2iOyxv7fn7GnPEZEw++tq+/YsM3U7i4jEiciLIlIlIkdFZIG/f84i8hX797pcRJ4RkXB/+5xF5I8i0iQi5UPaHP5cRWStff8TIrLWEQ0BYwhEJAj4FbAUyAfWiEi+uarcQh/wNaVUPnA98EX7+3oM2K6UygG221+D8f5z7I9HsXaN6C9jpDYf5CfAz5RSU4GLwKft7Z8GLtrbf2bfz4r8D7BZKZULzMJ47377OYtIOvDPQJFSagYQhFHPxN8+5z8BS65pc+hzFZF44LsYZYCLge8OGo8xoZQKiAdGzYTXh7z+JvBNs3V54H2uAxYDx4BUe1sqcMz+/LfAmiH7v7eflR4YFe22A7cDrwGCsdoy+NrPG6MWxgL782D7fmL2e3Dw/cYCp67V7c+fM/+odx5v/9xeA+70x88ZyALKnf1cgTXAb4e0v2+/0R4BMyLgH1+qQc7b2/wG+1B4DrAHSFZK1ds3NQDJ9uf+8n/4OfANYMD+OgG4pIxCSPD+9/Xee7Zvb7PvbyWygWbgCbs77PciEoUff85KqVrgv4CzQD3G57Yf//6cB3H0c3Xp8w4kQ+DXiEg08BLwL0qp9qHblNFF8Js4YRFZ8f+3d/asUQVRGH4ORCOxMGsXiRACkjaxCmghKClSxCadEFF/haTKHxCs/AEiCkqQYCP4URsNiEoU3aBgBI1YpLBKcVLMuXpRCzcb9rIz7wMLO2emOOe+C+/OmWEX2HL3taZz6SEDwEnghrtPAT/53S4AstS5BZwnmeAx4DB/t1Cypxe6lmQEX4DjtfFoxPoeMztAMoFb7r4c4W9mNhLzI8BWxHN4DqeAOTP7BNwhtYeuA8NmNhBr6nX9qjnmjwA/epnwPrAJbLr7sxjfIxlDzjqfAz66+3d33wGWSdrnrHNFp7p2pXdJRvAcOBE3Dg6SDp1WGs6pa8zMSP/1/Nbdr9WmVoDq5sBF0tlBFV+I2wfTwHZtC9oXuPtVdx919zGSjk/c/QLwFJiPZX/WXD2L+VjfV9+c3f0r8NnMJiJ0FlgnY51JLaFpMxuKz3lVc7Y61+hU14fAjJm1Yic1E7H/o+lDkh4fyMwC74ENYLHpfPapptOkbeMr4GW8Zkm90cfAB+ARcDTWG+n21AbwmnQjo/E6uqj/DPAg3o8Dq0AbuAsMRvxQjNsxP9503nusdRJ4EVrfB1q56wwsAe+AN8BNYDA3nYHbpDOQHdLO78pedAUuR+1t4FInOegnJoQQonBKag0JIYT4BzICIYQoHBmBEEIUjoxACCEKR0YghBCFIyMQQojCkREIIUTh7AJvF5/AQFxS3QAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "def clip_sin(x):\n", " x = clip_gradient(-0.75, 0.75, x)\n", " return jnp.sin(x)\n", "\n", "plt.plot(clip_sin(t))\n", "plt.plot(vmap(grad(clip_sin))(t))" ] }, { "cell_type": "markdown", "metadata": { "id": "CICQuI86WK4_" }, "source": [ "### Python debugging\n", "\n", "Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff." ] }, { "cell_type": "markdown", "metadata": { "id": "cgxMjNTrGjJn" }, "source": [ "When trying to track down the source of a `nan` runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with `jax.custom_vjp`.\n", "\n", "We'll defer an example until the next section." ] }, { "cell_type": "markdown", "metadata": { "id": "IC7tEcr1-Fc5" }, "source": [ "### Implicit function differentiation of iterative implementations\n", "\n", "This example gets pretty deep in the mathematical weeds!" ] }, { "cell_type": "markdown", "metadata": { "id": "szAt97t80hew" }, "source": [ "Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve `lax.while_loop`. 
(It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)\n", "\n", "For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "2uA8X2izXH2b" }, "outputs": [], "source": [ "from jax.lax import while_loop\n", "\n", "def fixed_point(f, a, x_guess):\n", " def cond_fun(carry):\n", " x_prev, x = carry\n", " return jnp.abs(x_prev - x) > 1e-6\n", "\n", " def body_fun(carry):\n", " _, x = carry\n", " return x, f(a, x)\n", "\n", " _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))\n", " return x_star" ] }, { "cell_type": "markdown", "metadata": { "id": "p2xFQAte19sF" }, "source": [ "This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. 
The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \\mapsto x^*(a)$ that is implicitly defined by the equation $x = f(a, x)$.\n", "\n", "We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "rDDwM8bYYzRT" }, "outputs": [], "source": [ "def newton_sqrt(a):\n", " update = lambda a, x: 0.5 * (x + a / x)\n", " return fixed_point(update, a, a)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "42Ydd7_6aLXU", "outputId": "c576dc92-33df-42b9-b2e8-ad54119514b1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.4142135\n" ] } ], "source": [ "print(newton_sqrt(2.))" ] }, { "cell_type": "markdown", "metadata": { "id": "-yFtYWH13QWm" }, "source": [ "We can `vmap` or `jit` the function as well:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "t_YSXieT3Yyk", "outputId": "76483e18-81f3-47a8-e8aa-e81535c01fe2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1. 1.4142135 1.7320508 2. ]\n" ] } ], "source": [ "print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))" ] }, { "cell_type": "markdown", "metadata": { "id": "emwWIt3d3h1T" }, "source": [ "We can't apply reverse-mode automatic differentiation because of the `while_loop`, but it turns out we wouldn't want to anyway: instead of differentiating through the implementation of `fixed_point` and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas's Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we're about to use. 
In essence, we linearize at the solution and solve those linear equations iteratively to compute the derivatives we want.\n", "\n", "Consider again the equation $x = f(a, x)$ and the function $x^*$. We want to evaluate vector-Jacobian products like $v^\\mathsf{T} \\mapsto v^\\mathsf{T} \\partial x^*(a_0)$.\n", "\n", "At least in an open neighborhood around the point $a_0$ at which we want to differentiate, let's assume that the equation $x^*(a) = f(a, x^*(a))$ holds for all $a$. Since the two sides are equal as functions of $a$, their derivatives must be equal as well, so let's differentiate both sides:\n", "\n", "$\\qquad \\partial x^*(a) = \\partial_0 f(a, x^*(a)) + \\partial_1 f(a, x^*(a)) \\partial x^*(a)$.\n", "\n", "Setting $A = \\partial_1 f(a_0, x^*(a_0))$ and $B = \\partial_0 f(a_0, x^*(a_0))$, we can write the quantity we're after more simply as\n", "\n", "$\\qquad \\partial x^*(a_0) = B + A \\partial x^*(a_0)$,\n", "\n", "or, by rearranging,\n", "\n", "$\\qquad \\partial x^*(a_0) = (I - A)^{-1} B$.\n", "\n", "That means we can evaluate vector-Jacobian products like\n", "\n", "$\\qquad v^\\mathsf{T} \\partial x^*(a_0) = v^\\mathsf{T} (I - A)^{-1} B = w^\\mathsf{T} B$,\n", "\n", "where $w^\\mathsf{T} = v^\\mathsf{T} (I - A)^{-1}$, or equivalently $w^\\mathsf{T} = v^\\mathsf{T} + w^\\mathsf{T} A$, or equivalently $w^\\mathsf{T}$ is the fixed point of the map $u^\\mathsf{T} \\mapsto v^\\mathsf{T} + u^\\mathsf{T} A$. That last characterization gives us a way to write the VJP for `fixed_point` in terms of a call to `fixed_point`! 
Moreover, after expanding $A$ and $B$ back out, we can see we need only to evaluate VJPs of $f$ at $(a_0, x^*(a_0))$.\n", "\n", "Here's the upshot:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "g4jo-xlvdiym" }, "outputs": [], "source": [ "from jax import vjp\n", "\n", "@partial(custom_vjp, nondiff_argnums=(0,))\n", "def fixed_point(f, a, x_guess):\n", " def cond_fun(carry):\n", " x_prev, x = carry\n", " return jnp.abs(x_prev - x) > 1e-6\n", "\n", " def body_fun(carry):\n", " _, x = carry\n", " return x, f(a, x)\n", "\n", " _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))\n", " return x_star\n", "\n", "def fixed_point_fwd(f, a, x_init):\n", " x_star = fixed_point(f, a, x_init)\n", " return x_star, (a, x_star)\n", "\n", "def fixed_point_rev(f, res, x_star_bar):\n", " a, x_star = res\n", " _, vjp_a = vjp(lambda a: f(a, x_star), a)\n", " a_bar, = vjp_a(fixed_point(partial(rev_iter, f),\n", " (a, x_star, x_star_bar),\n", " x_star_bar))\n", " return a_bar, jnp.zeros_like(x_star)\n", " \n", "def rev_iter(f, packed, u):\n", " a, x_star, x_star_bar = packed\n", " _, vjp_x = vjp(lambda x: f(a, x), x_star)\n", " return x_star_bar + vjp_x(u)[0]\n", "\n", "fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "iKzfT6d_mEoB", "outputId": "5d04c4a0-61dd-42de-ffa4-101b71d15a57" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.4142135\n" ] } ], "source": [ "print(newton_sqrt(2.))" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "Hmcpjr6gmtkO", "outputId": "9c4a406c-0144-4d5f-e789-a7a4c850a3cc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.35355335\n", "-0.088388346\n" ] } ], "source": [ "print(grad(newton_sqrt)(2.))\n", "print(grad(grad(newton_sqrt))(2.))" ] }, { "cell_type": "markdown", "metadata": { "id": "DvVmlaPD7W-4" }, "source": [ "We can check our answers by differentiating 
`jnp.sqrt`, which uses a totally different implementation:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "jj_JnI9Pm4jg", "outputId": "6eb3e158-209b-41f2-865c-376a1d07624b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.35355338\n", "-0.08838835\n" ] } ], "source": [ "print(grad(jnp.sqrt)(2.))\n", "print(grad(grad(jnp.sqrt))(2.))" ] }, { "cell_type": "markdown", "metadata": { "id": "HowvqayEuy-H" }, "source": [ "A limitation of this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions." ] }, { "cell_type": "markdown", "metadata": { "id": "Dr0aNkBslfQf" }, "source": [ "## Basic usage of `jax.custom_jvp` and `jax.custom_vjp` APIs" ] }, { "cell_type": "markdown", "metadata": { "id": "MojTOg4tmQNT" }, "source": [ "### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules\n", "\n", "Here's a canonical basic example of using `jax.custom_jvp`, where the comments use\n", "[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "nVkhbIFAOGZk" }, "outputs": [], "source": [ "from jax import custom_jvp\n", "import jax.numpy as jnp\n", "\n", "# f :: a -> b\n", "@custom_jvp\n", "def f(x):\n", " return jnp.sin(x)\n", "\n", "# f_jvp :: (a, T a) -> (b, T b)\n", "def f_jvp(primals, tangents):\n", " x, = primals\n", " t, = tangents\n", " return f(x), jnp.cos(x) * t\n", "\n", "f.defjvp(f_jvp)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "fxhlECvW7Krj", "outputId": "30dc5e8b-d157-4ae2-cd17-145d4e1ba47b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.14112\n", 
"0.14112\n", "-0.9899925\n" ] } ], "source": [ "from jax import jvp\n", "\n", "print(f(3.))\n", "\n", "y, y_dot = jvp(f, (3.,), (1.,))\n", "print(y)\n", "print(y_dot)" ] }, { "cell_type": "markdown", "metadata": { "id": "JaoQVRzSQ9Qd" }, "source": [ "In words, we start with a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it a JVP rule function `f_jvp` that takes a pair of inputs representing the primal inputs of type `a` and the corresponding tangent inputs of type `T a`, and produces a pair of outputs representing the primal outputs of type `b` and tangent outputs of type `T b`. The tangent outputs should be a linear function of the tangent inputs." ] }, { "cell_type": "markdown", "metadata": { "id": "1xGky7yMOavq" }, "source": [ "You can also use `f.defjvp` as a decorator, as in\n", "\n", "```python\n", "@custom_jvp\n", "def f(x):\n", " ...\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", " ...\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "e9R-ppvdQIOC" }, "source": [ "Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on `f`. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "hl9Io86pQD6s", "outputId": "a9ef39aa-4df0-459f-ee1d-64b648cabcc4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-0.9899925\n", "-0.14112\n" ] } ], "source": [ "from jax import grad\n", "\n", "print(grad(f)(3.))\n", "print(grad(grad(f))(3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "MRlKe5D90svj" }, "source": [ "For automatic transposition to work, the JVP rule's output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised." 
] }, { "cell_type": "markdown", "metadata": { "id": "GRu-0yg96lXE" }, "source": [ "Multiple arguments work like this:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "id": "JFLXlXuq6pRf" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x, y):\n", " return x ** 2 * y\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", " x, y = primals\n", " x_dot, y_dot = tangents\n", " primal_out = f(x, y)\n", " tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot\n", " return primal_out, tangent_out" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "QpKwA0oA8DfE", "outputId": "80855f56-04a5-4179-fd8b-199ea7eba476" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "12.0\n" ] } ], "source": [ "print(grad(f)(2., 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "YPsPS3rdaGo2" }, "source": [ "The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "id": "CsQIUhUkajua" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x):\n", " return jnp.sin(x)\n", "\n", "f.defjvps(lambda t, ans, x: jnp.cos(x) * t)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "id": "zfSgXrPEap-i", "outputId": "bf552090-a60d-4c2a-fc91-603396df94cd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-0.9899925\n" ] } ], "source": [ "print(grad(f)(3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "iYUCLJghbPiP" }, "source": [ "Here's a `defjvps` example with multiple arguments:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "id": "Vx4Jv9s9bCi1" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x, y):\n", " return x ** 2 * y\n", "\n", "f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,\n", " lambda y_dot, primal_out, x, y: x ** 2 * y_dot)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "id": 
"o9ezUYsjbbvC", "outputId": "f60f4941-d5e3-49c3-920f-76fd92414697" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "12.0\n", "12.0\n", "4.0\n" ] } ], "source": [ "print(grad(f)(2., 3.))\n", "print(grad(f, 0)(2., 3.)) # same as above\n", "print(grad(f, 1)(2., 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "nuIUkaxibVfD" }, "source": [ "As a shorthand, with `defjvps` you can pass a `None` value to indicate that the JVP for a particular argument is zero:" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "id": "z4z3esdZbTzQ" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x, y):\n", " return x ** 2 * y\n", "\n", "f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,\n", " None)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "id": "jOtQfp-5btSo", "outputId": "b60aa797-4c1e-4421-826d-691ba418bc1d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "12.0\n", "12.0\n", "0.0\n" ] } ], "source": [ "print(grad(f)(2., 3.))\n", "print(grad(f, 0)(2., 3.)) # same as above\n", "print(grad(f, 1)(2., 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "kZ0yc-Ihoezk" }, "source": [ "Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism." 
] }, { "cell_type": "markdown", "metadata": { "id": "3FGwfT67PDs9" }, "source": [ "When you're not performing differentiation, the function `f` is called just as if it weren't decorated by `jax.custom_jvp`:" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "id": "b-tB3xCHPRFt" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x):\n", " print('called f!') # a harmless side-effect\n", " return jnp.sin(x)\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", " print('called f_jvp!') # a harmless side-effect\n", " x, = primals\n", " t, = tangents\n", " return f(x), jnp.cos(x) * t" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "id": "xAlRea95PjA5", "outputId": "10b4db9e-3192-415e-ac1c-0dc57c7dc086" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f!\n", "0.14112\n" ] } ], "source": [ "from jax import vmap, jit\n", "\n", "print(f(3.))" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "id": "dyD2ow4NmpI-", "outputId": "1d66b67f-c1b4-4a9d-d6ed-12d88767842c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f!\n", "[0. 
0.841471 0.9092974]\n", "called f!\n", "0.14112\n" ] } ], "source": [ "print(vmap(f)(jnp.arange(3.)))\n", "print(jit(f)(3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "EzB75KZ5Pz7m" }, "source": [ "The custom JVP rule is invoked during differentiation, whether forward or reverse:" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "id": "hKF0xyAxPyLZ", "outputId": "214cc5a7-a992-41c8-aa01-8ea4b2b3b4d6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f_jvp!\n", "called f!\n", "-0.9899925\n" ] } ], "source": [ "y, y_dot = jvp(f, (3.,), (1.,))\n", "print(y_dot)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "id": "Z1KaEgA58MEG", "outputId": "86263d76-5a98-4d96-f5c2-9146bcf1b6fd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f_jvp!\n", "called f!\n", "-0.9899925\n" ] } ], "source": [ "print(grad(f)(3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "o8JFxk3lQhOs" }, "source": [ "Notice that `f_jvp` calls `f` to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original `f` to compute the primal outputs. 
(This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of `f` in our rule _and also_ have the rule apply in all orders of higher-order differentiation.)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "id": "B6PLJooTQgVp", "outputId": "0d7ac628-656e-4b67-d285-f810155b6b9c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f_jvp!\n", "called f_jvp!\n", "called f!\n" ] }, { "data": { "text/plain": [ "DeviceArray(-0.14112, dtype=float32)" ] }, "execution_count": 50, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "grad(grad(f))(3.)" ] }, { "cell_type": "markdown", "metadata": { "id": "XNxAmFSsaaro" }, "source": [ "You can use Python control flow with `jax.custom_jvp`:" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "id": "kkXlSJL6adU2" }, "outputs": [], "source": [ "@custom_jvp\n", "def f(x):\n", " if x > 0:\n", " return jnp.sin(x)\n", " else:\n", " return jnp.cos(x)\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", " x, = primals\n", " x_dot, = tangents\n", " ans = f(x)\n", " if x > 0:\n", " return ans, 2 * x_dot\n", " else:\n", " return ans, 3 * x_dot" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "id": "QCHmJ56Na2G3", "outputId": "1772d3b4-44ef-4745-edd3-553c6312c553" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.0\n", "3.0\n" ] } ], "source": [ "print(grad(f)(1.))\n", "print(grad(f)(-1.))" ] }, { "cell_type": "markdown", "metadata": { "id": "9cVdgR7ilt8l" }, "source": [ "### Use `jax.custom_vjp` to define custom reverse-mode-only rules\n", "\n", "While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. 
We can do that with `jax.custom_vjp`:" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "id": "zAZk1n3dUw76" }, "outputs": [], "source": [ "from jax import custom_vjp\n", "import jax.numpy as jnp\n", "\n", "# f :: a -> b\n", "@custom_vjp\n", "def f(x):\n", " return jnp.sin(x)\n", "\n", "# f_fwd :: a -> (b, c)\n", "def f_fwd(x):\n", " return f(x), jnp.cos(x)\n", "\n", "# f_bwd :: (c, CT b) -> CT a\n", "def f_bwd(cos_x, y_bar):\n", " return (cos_x * y_bar,)\n", "\n", "f.defvjp(f_fwd, f_bwd)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "id": "E8W-H2S0Ngdr", "outputId": "cd0dc221-e779-436d-f3b4-21e799f40620" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.14112\n", "-0.9899925\n" ] } ], "source": [ "from jax import grad\n", "\n", "print(f(3.))\n", "print(grad(f)(3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "yLING7qEVGGN" }, "source": [ "In words, we again start with a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it two functions, `f_fwd` and `f_bwd`, which describe how to perform the forward- and backward-passes of reverse-mode autodiff, respectively.\n", "\n", "The function `f_fwd` describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function `f`, in that it takes a primal input of type `a`. But as output it produces a pair, where the first element is the primal output `b` and the second element is any \"residual\" data of type `c` to be stored for use by the backward pass. (This second output is analogous to [PyTorch's save_for_backward mechanism](https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html).)\n", "\n", "The function `f_bwd` describes the backward pass. 
It takes two inputs, where the first is the residual data of type `c` produced by `f_fwd` and the second is the output cotangents of type `CT b` corresponding to the output of the primal function. It produces an output of type `CT a` representing the cotangents corresponding to the input of the primal function. In particular, the output of `f_bwd` must be a sequence (e.g. a tuple) of length equal to the number of arguments to the primal function." ] }, { "cell_type": "markdown", "metadata": { "id": "d1b5v67Oncfz" }, "source": [ "So multiple arguments work like this:" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "id": "IhMb64gkngAt" }, "outputs": [], "source": [ "from jax import custom_vjp\n", "\n", "@custom_vjp\n", "def f(x, y):\n", " return jnp.sin(x) * y\n", "\n", "def f_fwd(x, y):\n", " return f(x, y), (jnp.cos(x), jnp.sin(x), y)\n", "\n", "def f_bwd(res, g):\n", " cos_x, sin_x, y = res\n", " # Cotangents for (x, y): df/dx = cos(x) * y and df/dy = sin(x).\n", " return (cos_x * g * y, sin_x * g)\n", "\n", "f.defvjp(f_fwd, f_bwd)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "id": "EnRtIhhLnkry", "outputId": "e03907ec-463a-4f3c-ae8e-feecb4394b2b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-1.2484405\n" ] } ], "source": [ "print(grad(f)(2., 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "GwC26P9kn8qw" }, "source": [ "Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism." ] }, { "cell_type": "markdown", "metadata": { "id": "XfH-ae8bYt6-" }, "source": [ "As with `jax.custom_jvp`, the custom VJP rule made up of `f_fwd` and `f_bwd` is not invoked if differentiation is not applied. 
If the function is evaluated, or transformed with `jit`, `vmap`, or other non-differentiation transformations, then only `f` is called." ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "id": "s-_Dbqi-N5Ij" }, "outputs": [], "source": [ "@custom_vjp\n", "def f(x):\n", " print(\"called f!\")\n", " return jnp.sin(x)\n", "\n", "def f_fwd(x):\n", " print(\"called f_fwd!\")\n", " return f(x), jnp.cos(x)\n", "\n", "def f_bwd(cos_x, y_bar):\n", " print(\"called f_bwd!\")\n", " return (cos_x * y_bar,)\n", "\n", "f.defvjp(f_fwd, f_bwd)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "id": "r0aZ79OmOAR5", "outputId": "9cf16d9e-ca96-4987-e01a-dc0e22405576" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f!\n", "0.14112\n" ] } ], "source": [ "print(f(3.))" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "id": "7ToB9BYlm6uN", "outputId": "aa9f3e3f-e6c3-4ee4-b87a-4526074f43aa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f_fwd!\n", "called f!\n", "called f_bwd!\n", "-0.9899925\n" ] } ], "source": [ "print(grad(f)(3.))" ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "id": "s1Pn_qCIODcF", "outputId": "423d34e0-35b8-4b57-e89d-f70f20e28ea9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f_fwd!\n", "called f!\n", "0.14112\n" ] } ], "source": [ "from jax import vjp\n", "\n", "y, f_vjp = vjp(f, 3.)\n", "print(y)" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "id": "dvgQtDHaOHuo", "outputId": "d92649c5-0aab-49a9-9158-f7ddc5fccb9b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f_bwd!\n", "(DeviceArray(-0.9899925, dtype=float32),)\n" ] } ], "source": [ "print(f_vjp(1.))" ] }, { "cell_type": "markdown", "metadata": { "id": "qFIIpkFcZCNP" }, "source": [ "**Forward-mode autodiff cannot be used on the** `jax.custom_vjp` **function** and will raise an error:" ] }, { "cell_type": 
"code", "execution_count": 62, "metadata": { "id": "3RGQRbI_OSEX", "outputId": "6385a024-7a10-445a-8380-b2eef722e597" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "called f_fwd!\n", "called f!\n", "ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.\n" ] } ], "source": [ "from jax import jvp\n", "\n", "try:\n", " jvp(f, (3.,), (1.,))\n", "except TypeError as e:\n", " print('ERROR! {}'.format(e))" ] }, { "cell_type": "markdown", "metadata": { "id": "u04I9j2dntAU" }, "source": [ "If you want to use both forward- and reverse-mode, use `jax.custom_jvp` instead." ] }, { "cell_type": "markdown", "metadata": { "id": "YN97y7LEZbWV" }, "source": [ "We can use `jax.custom_vjp` together with `pdb` to insert a debugger trace in the backward pass:" ] }, { "cell_type": "code", "execution_count": 63, "metadata": { "id": "-DvRKsHPZk_g" }, "outputs": [], "source": [ "import pdb\n", "\n", "@custom_vjp\n", "def debug(x):\n", " return x # acts like identity\n", "\n", "def debug_fwd(x):\n", " return x, x\n", "\n", "def debug_bwd(x, g):\n", " import pdb; pdb.set_trace()\n", " return g\n", "\n", "debug.defvjp(debug_fwd, debug_bwd)" ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "id": "49GdkP4pZ2IV" }, "outputs": [], "source": [ "def foo(x):\n", " y = x ** 2\n", " y = debug(y) # insert pdb in corresponding backward pass step\n", " return jnp.sin(y)" ] }, { "cell_type": "markdown", "metadata": { "id": "sGLnRcPwaKoX" }, "source": [ "```python\n", "jax.grad(foo)(3.)\n", "\n", "> (12)debug_bwd()\n", "-> return g\n", "(Pdb) p x\n", "DeviceArray(9., dtype=float32)\n", "(Pdb) p g\n", "DeviceArray(-0.91113025, dtype=float32)\n", "(Pdb) q\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "DaTfAJLAl1Lb" }, "source": [ "## More features and details" ] }, { "cell_type": "markdown", "metadata": { "id": "LQF_UDApl_UV" }, "source": [ "### Working with `list` / `tuple` / `dict` containers (and other pytrees)\n", "\n", "You 
should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints.\n", "\n", "Here's a contrived example with `jax.custom_jvp`:" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "id": "6sDLZ3dAn3P2" }, "outputs": [], "source": [ "from collections import namedtuple\n", "Point = namedtuple(\"Point\", [\"x\", \"y\"])\n", "\n", "@custom_jvp\n", "def f(pt):\n", " x, y = pt.x, pt.y\n", " return {'a': x ** 2,\n", " 'b': (jnp.sin(x), jnp.cos(y))}\n", "\n", "@f.defjvp\n", "def f_jvp(primals, tangents):\n", " pt, = primals\n", " pt_dot, = tangents # a Point of tangents, same structure as pt\n", " ans = f(pt)\n", " # ans_dot must have the same pytree structure as ans\n", " ans_dot = {'a': 2 * pt.x * pt_dot.x,\n", " 'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}\n", " return ans, ans_dot\n", "\n", "def fun(pt):\n", " dct = f(pt)\n", " return dct['a'] + dct['b'][0]" ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "id": "My8pbOlPppJj", "outputId": "04cc1129-d0fb-4018-bec1-2ccf8b7906e3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n" ] } ], "source": [ "pt = Point(1., 2.)\n", "\n", "print(f(pt))" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "id": "a9qyiCAhqLd3", "outputId": "08bd0615-7c35-44ff-f90b-c175618c2c40" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Point(x=DeviceArray(2.5403023, dtype=float32), y=array(0., dtype=float32))\n" ] } ], "source": [ "print(grad(fun)(pt))" ] }, { "cell_type": "markdown", "metadata": { "id": "BWLN9tu4qWQd" }, "source": [ "And an analogous contrived example with `jax.custom_vjp`:" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "id": "QkdbwGkJqS3J" }, "outputs": [], "source": [ 
"@custom_vjp\n", "def f(pt):\n", " x, y = pt.x, pt.y\n", " return {'a': x ** 2,\n", " 'b': (jnp.sin(x), jnp.cos(y))}\n", "\n", "def f_fwd(pt):\n", " return f(pt), pt\n", "\n", "def f_bwd(pt, g):\n", " # g has the same pytree structure as the output of f\n", " a_bar, (b0_bar, b1_bar) = g['a'], g['b']\n", " x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar\n", " y_bar = -jnp.sin(pt.y) * b1_bar\n", " return (Point(x_bar, y_bar),)\n", "\n", "f.defvjp(f_fwd, f_bwd)\n", "\n", "def fun(pt):\n", " dct = f(pt)\n", " return dct['a'] + dct['b'][0]" ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "id": "3onW7t6nrJ4E", "outputId": "ac455ab0-cac0-41fc-aea3-034931316053" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n" ] } ], "source": [ "pt = Point(1., 2.)\n", "\n", "print(f(pt))" ] }, { "cell_type": "code", "execution_count": 70, "metadata": { "id": "ryyeKIXtrNpd", "outputId": "1780f738-ffd8-4ed7-ffbe-71d84bd62709" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Point(x=DeviceArray(2.5403023, dtype=float32), y=DeviceArray(-0., dtype=float32))\n" ] } ], "source": [ "print(grad(fun)(pt))" ] }, { "cell_type": "markdown", "metadata": { "id": "JKTNivxbmKWO" }, "source": [ "### Handling non-differentiable arguments" ] }, { "cell_type": "markdown", "metadata": { "id": "7g9sXSp_uc36" }, "source": [ "Some use cases, like the `fixed_point` example above, call for non-differentiable arguments, such as function-valued arguments, to be passed to functions with custom differentiation rules and also to the rules themselves. In `fixed_point`, the function argument `f` was such a non-differentiable argument. A similar situation arises with `jax.experimental.ode.odeint`."
] }, { "cell_type": "markdown", "metadata": { "id": "9yNIOzyBCvE5" }, "source": [ "#### `jax.custom_jvp` with `nondiff_argnums`\n", "\n", "Use the optional `nondiff_argnums` parameter to `jax.custom_jvp` to indicate arguments like these. Here's an example:" ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "id": "b3YMxxTBvy0I" }, "outputs": [], "source": [ "from functools import partial\n", "\n", "@partial(custom_jvp, nondiff_argnums=(0,))\n", "def app(f, x):\n", " return f(x)\n", "\n", "@app.defjvp\n", "def app_jvp(f, primals, tangents):\n", " x, = primals\n", " x_dot, = tangents\n", " # arbitrary tangent (not the true derivative of f), to show the custom rule is used\n", " return f(x), 2. * x_dot" ] }, { "cell_type": "code", "execution_count": 72, "metadata": { "id": "5W-yEw9IB34S", "outputId": "a2c1444a-9cc7-43ee-cb52-6c5d1cec02f1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "27.0\n" ] } ], "source": [ "print(app(lambda x: x ** 3, 3.))" ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "id": "zbVIlOmqB7_O", "outputId": "a0174f54-89b0-4957-9362-c05af922f974" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.0\n" ] } ], "source": [ "print(grad(app, 1)(lambda x: x ** 3, 3.))" ] }, { "cell_type": "markdown", "metadata": { "id": "-b_B_4WaBI2D" }, "source": [ "Notice the gotcha here: no matter where in the argument list these parameters appear, they're placed at the *start* of the signature of the corresponding JVP rule. Here's another example:" ] }, { "cell_type": "code", "execution_count": 74, "metadata": { "id": "9hokWmyHBgKK" }, "outputs": [], "source": [ "@partial(custom_jvp, nondiff_argnums=(0, 2))\n", "def app2(f, x, g):\n", " return f(g(x))\n", "\n", "@app2.defjvp\n", "def app2_jvp(f, g, primals, tangents):\n", " x, = primals\n", " x_dot, = tangents\n", " return f(g(x)), 3. 
* x_dot" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "id": "J7GsvJTgCfS0", "outputId": "43dd6a02-2e4e-449e-924a-d1a03fe622fe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3375.0\n" ] } ], "source": [ "print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))" ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "id": "kPP8Jt1CCb1X", "outputId": "6eff9aae-8d6e-4998-92ed-56272c32d6e8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.0\n" ] } ], "source": [ "print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))" ] }, { "cell_type": "markdown", "metadata": { "id": "ECbalHIkC4ts" }, "source": [ "#### `jax.custom_vjp` with `nondiff_argnums`" ] }, { "cell_type": "markdown", "metadata": { "id": "0u0jn4aWC8k1" }, "source": [ "A similar option exists for `jax.custom_vjp`, and, similarly, the convention is that the non-differentiable arguments are passed as the first arguments to the `_bwd` rule, no matter where they appear in the signature of the original function. The signature of the `_fwd` rule remains unchanged - it is the same as the signature of the primal function. 
Here's an example:" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "id": "yCdu-_9GClWs" }, "outputs": [], "source": [ "@partial(custom_vjp, nondiff_argnums=(0,))\n", "def app(f, x):\n", " return f(x)\n", "\n", "def app_fwd(f, x):\n", " return f(x), x\n", "\n", "def app_bwd(f, x, g):\n", " return (5 * g,)\n", "\n", "app.defvjp(app_fwd, app_bwd)" ] }, { "cell_type": "code", "execution_count": 78, "metadata": { "id": "qSgcWa1eDj4r", "outputId": "43939686-f857-47ea-9f85-53f440ef12ee" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "16.0\n" ] } ], "source": [ "print(app(lambda x: x ** 2, 4.))" ] }, { "cell_type": "code", "execution_count": 79, "metadata": { "id": "tccagflcDmaz", "outputId": "c75ca70b-2431-493b-e335-4f4d340902f1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5.0\n" ] } ], "source": [ "print(grad(app, 1)(lambda x: x ** 2, 4.))" ] }, { "cell_type": "markdown", "metadata": { "id": "BTEnNTk5D0sM" }, "source": [ "See `fixed_point` above for another usage example.\n", "\n", "**You don't need to use** `nondiff_argnums` **with array-valued arguments**, for example ones with integer dtype. Instead, `nondiff_argnums` should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by `nondiff_argnums` contains a JAX Tracer, then an error is raised. The `clip_gradient` function above is a good example of not using `nondiff_argnums` for integer-dtype array arguments." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "Custom derivative rules for Python code.ipynb", "provenance": [], "toc_visible": true }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }