
Autobatching log-densities example

This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.

Inspired by a notebook by @davmre.
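
For readers new to autobatching, here is a minimal sketch (not part of the original notebook) of what jax.vmap does: it takes a function written for a single example and returns one that maps it over a leading batch axis.

import jax
import jax.numpy as jnp

def squared_norm(x):                            # written for a single vector x
    return jnp.dot(x, x)

batched_squared_norm = jax.vmap(squared_norm)   # maps over axis 0 by default
batched_squared_norm(jnp.ones((4, 3)))          # shape (4,): one squared norm per row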

[1]:
from matplotlib.pyplot import *

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random

import numpy as np
import scipy as sp

Generate a fake binary classification dataset

[2]:
np.random.seed(10009)

num_features = 10
num_points = 100

true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32)
[3]:
y
[3]:
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
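
As a quick sanity check on the synthetic data (a small addition, not in the original notebook), the empirical positive rate should be close to the average Bernoulli probability implied by the true coefficients:

# Compare the fraction of positive labels with the mean predicted probability.
print(y.mean(), sp.special.expit(all_x.dot(true_beta)).mean())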

Write the log-joint function for the model

We’ll write a non-batched version, a manually batched version, and an autobatched version.

Non-batched

[4]:
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result
[5]:
log_joint(np.random.randn(num_features))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[5]:
DeviceArray(-213.23558, dtype=float32)
[6]:
# This doesn't work, because we didn't write `log_joint()` to handle batching.
try:
  batch_size = 10
  batched_test_beta = np.random.randn(batch_size, num_features)

  log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
  print("Caught expected exception " + str(e))
Caught expected exception Incompatible shapes for broadcasting: ((100, 10), (1, 100))
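
To see where the failure comes from (an illustrative shape check, not in the original notebook): with a batched beta of shape (10, 10), jnp.dot(all_x, beta) has shape (100, 10), which cannot broadcast against 2*y - 1 of shape (100,).

# The dot product now carries a trailing batch dimension that `2*y - 1` cannot broadcast against.
print(jnp.dot(all_x, batched_test_beta).shape)   # (100, 10)
print((2*y - 1).shape)                           # (100,)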

Manually batched

[7]:
def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                              axis=-1)
    # Note the multiple transposes: `jnp.dot(all_x, beta.T)` has shape
    # (num_points, batch_size), and transposing it back gives
    # (batch_size, num_points), which broadcasts against `2*y - 1` of shape
    # (num_points,). Getting this right is not rocket science, but it's also
    # not totally mindless. (I didn't get it right on the first try.)
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),
                              axis=-1)
    return result
[8]:
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)

batched_log_joint(batched_test_beta)
[8]:
DeviceArray([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908,
             -143.84848, -160.28772, -113.7717 , -126.60544, -190.81989],
            dtype=float32)
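
As a check (not in the original notebook), the manually batched version agrees with simply looping the non-batched log_joint over the batch:

# Loop over the batch with the non-batched function and compare.
looped = jnp.stack([log_joint(beta) for beta in batched_test_beta])
print(jnp.allclose(looped, batched_log_joint(batched_test_beta), atol=1e-4))   # expected: True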

Autobatched with vmap

It just works.

[9]:
vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
[9]:
DeviceArray([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908,
             -143.84848, -160.28772, -113.7717 , -126.60544, -190.81989],
            dtype=float32)
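
By default jax.vmap maps over the leading axis of every argument; the in_axes argument lets you name a different batch axis. For example (illustrative only, and a bit degenerate here because the batch and feature sizes are both 10), mapping over the columns of the transposed batch gives the same values:

vmap_over_columns = jax.vmap(log_joint, in_axes=1)   # batch axis is axis 1
vmap_over_columns(batched_test_beta.T)               # same values as above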

Self-contained variational inference example

A little code is copied from above.

Set up the (batched) log-joint function

[10]:
@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result

batched_log_joint = jax.jit(jax.vmap(log_joint))

Define the ELBO and its gradient

[11]:
def elbo(beta_loc, beta_log_scale, epsilon):
    # Reparameterization trick: sample via loc + scale * standard-normal noise.
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    # Expected log-joint plus (up to an additive constant) the Gaussian entropy.
    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))

elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
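
jax.value_and_grad with argnums=(0, 1) returns the ELBO value together with a tuple of gradients, one per variational parameter. A minimal usage sketch (placeholder inputs, not part of the optimization below):

test_loc = jnp.zeros(num_features)
test_log_scale = jnp.zeros(num_features)
test_eps = random.normal(random.PRNGKey(0), (32, num_features))
val, (grad_loc, grad_log_scale) = elbo_val_and_grad(test_loc, test_log_scale, test_eps)
print(val.shape, grad_loc.shape, grad_log_scale.shape)   # (), (10,), (10,)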

Optimize the ELBO using SGD

[12]:
def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

# `shape` must be a static argument: it is a Python tuple used as an array
# shape, so it has to be a concrete value at trace time (each new shape
# triggers a fresh compilation).
normal_sample = jax.jit(normal_sample, static_argnums=(1,))

key = random.PRNGKey(10003)

beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    # Gradient *ascent* on the ELBO (we are maximizing it), hence `+=`.
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
0       -180.8538818359375
10      -113.06045532226562
20      -102.73725891113281
30      -99.787353515625
40      -98.90898132324219
50      -98.29745483398438
60      -98.18630981445312
70      -97.5797348022461
80      -97.28600311279297
90      -97.469970703125
100     -97.4771728515625
110     -97.58067321777344
120     -97.49435424804688
130     -97.50271606445312
140     -96.86395263671875
150     -97.44197082519531
160     -97.06939697265625
170     -96.84028625488281
180     -97.21336364746094
190     -97.56502532958984
200     -97.26398468017578
210     -97.11979675292969
220     -97.39593505859375
230     -97.16830444335938
240     -97.118408203125
250     -97.24345397949219
260     -97.2978744506836
270     -96.69285583496094
280     -96.9643783569336
290     -97.30055236816406
300     -96.63594055175781
310     -97.03518676757812
320     -97.52909851074219
330     -97.28812408447266
340     -97.0732192993164
350     -97.15620422363281
360     -97.25882720947266
370     -97.19515228271484
380     -97.13092041015625
390     -97.11727905273438
400     -96.93873596191406
410     -97.26676940917969
420     -97.35324096679688
430     -97.21007537841797
440     -97.28434753417969
450     -97.16310119628906
460     -97.2612533569336
470     -97.21343994140625
480     -97.23997497558594
490     -97.14913177490234
500     -97.23528289794922
510     -96.9342041015625
520     -97.21209716796875
530     -96.82577514648438
540     -97.01286315917969
550     -96.94176483154297
560     -97.16522216796875
570     -97.29165649414062
580     -97.42939758300781
590     -97.24371337890625
600     -97.15219116210938
610     -97.49844360351562
620     -96.99070739746094
630     -96.88957977294922
640     -96.89970397949219
650     -97.13794708251953
660     -97.43707275390625
670     -96.99235534667969
680     -97.15623474121094
690     -97.18690490722656
700     -97.11160278320312
710     -97.78105163574219
720     -97.23226165771484
730     -97.16206359863281
740     -96.99581909179688
750     -96.66722869873047
760     -97.16796112060547
770     -97.51435089111328
780     -97.28901672363281
790     -96.91226196289062
800     -97.1709976196289
810     -97.29047393798828
820     -97.16242980957031
830     -97.1910629272461
840     -97.56382751464844
850     -97.00193786621094
860     -96.86555480957031
870     -96.76337432861328
880     -96.83661651611328
890     -97.12179565429688
900     -97.09554290771484
910     -97.0682373046875
920     -97.11947631835938
930     -96.8792953491211
940     -97.45625305175781
950     -96.69280242919922
960     -97.29376220703125
970     -97.3353042602539
980     -97.34962463378906
990     -97.09674835205078

Display the results

Coverage isn’t quite as good as we might like, but it’s not bad, and nobody said variational inference was exact.

[13]:
figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
[13]:
<matplotlib.legend.Legend at 0x7fe7687b2a10>
[Figure: estimated beta vs. true beta, showing approximate posterior means and 2σ error bars around the y = x line]
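
The coverage remark above can also be checked numerically (a small addition, not in the original notebook):

# Fraction of true coefficients that fall inside the plotted 2-sigma intervals.
lower = beta_loc - 2 * jnp.exp(beta_log_scale)
upper = beta_loc + 2 * jnp.exp(beta_log_scale)
coverage = jnp.mean((true_beta >= lower) & (true_beta <= upper))
print("Fraction of true coefficients inside the 2-sigma interval:", float(coverage))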