Autobatching log-densities example
This notebook demonstrates a simple Bayesian inference example in which autobatching with `jax.vmap` makes user code easier to write, easier to read, and less likely to contain bugs.
Inspired by a notebook by @davmre.
[1]:
import functools
import itertools
import re
import sys
import time
from matplotlib.pyplot import *
import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
import numpy as np
import scipy as sp
Generate a fake binary classification dataset
[2]:
# Seed NumPy's global RNG so the synthetic dataset (and later NumPy draws) is reproducible.
np.random.seed(10009)

num_features = 10
num_points = 100

# Ground-truth coefficients and a standard-normal design matrix, both float32.
true_beta = np.random.standard_normal(num_features).astype(np.float32)
all_x = np.random.standard_normal((num_points, num_features)).astype(np.float32)

# Binary labels: each point is 1 with probability sigmoid(x . true_beta), else 0.
y = (np.random.random_sample(num_points)
     < sp.special.expit(all_x.dot(true_beta))).astype(np.int32)
[3]:
# Display the generated 0/1 labels.
y
[3]:
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
Write the log-joint function for the model
We’ll write a non-batched version, a manually batched version, and an autobatched version.
Non-batched
[4]:
def log_joint(beta):
    """Unnormalized log joint density for a single parameter vector `beta`.

    Sums an independent N(0, 1) log prior over the coefficients and a logistic
    log likelihood over all data points. Note that no `axis` parameter is
    provided to `jnp.sum`, so this version only handles one vector at a time.
    """
    log_prior = jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    # 2*y - 1 maps the {0, 1} labels to {-1, +1}.
    neg_margins = -(2*y - 1) * jnp.dot(all_x, beta)
    log_likelihood = jnp.sum(-jnp.log(1 + jnp.exp(neg_margins)))
    return log_prior + log_likelihood
[5]:
# Evaluate the log joint at one random parameter vector drawn from N(0, I).
log_joint(beta=np.random.randn(num_features))
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/stable/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
[5]:
DeviceArray(-213.23558, dtype=float32)
[6]:
# This doesn't work, because we didn't write `log_joint()` to handle batching:
# with a (batch_size, num_features) input, `jnp.dot(all_x, beta)` is no longer
# the per-point logit vector, so the computation fails. Which exception is raised
# depends on the shapes (here, broadcasting against the labels fails; presumably
# `jnp.dot` itself rejects the shapes when batch_size != num_features), so catch both.
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)
try:
    log_joint(batched_test_beta)
except (ValueError, TypeError) as e:
    print("Caught expected exception " + str(e))
else:
    # The whole point of this cell is the failure; make it loud if it ever succeeds.
    raise AssertionError("Expected the non-batched log_joint to reject a batch of parameter vectors")
Caught expected exception Incompatible shapes for broadcasting: ((100, 10), (1, 100))
Manually batched
[7]:
def batched_log_joint(beta):
    """Hand-vectorized log joint: one value per row of `beta`.

    Each row of `beta` is a separate parameter vector.
    """
    # Every reduction now needs `axis=-1` so it sums within a row rather than
    # across the batch. At best, forgetting to set axis (or setting it wrong)
    # yields an error; at worst, it silently changes the semantics of the model.
    log_prior = jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.), axis=-1)
    # The transposes are not rocket science, but not totally mindless either:
    # the inner `.T` lines the feature axis up with all_x, and the outer `.T`
    # moves the batch axis to the front so it broadcasts against the
    # per-point signs 2*y - 1.
    logits = jnp.dot(all_x, beta.T).T
    log_likelihood = jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1) * logits)), axis=-1)
    return log_prior + log_likelihood
[8]:
# Draw a batch of random parameter vectors (one per row) and score them all at once.
batch_size = 10
batched_test_beta = np.random.standard_normal((batch_size, num_features))
batched_log_joint(batched_test_beta)
[8]:
DeviceArray([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908,
-143.84848, -160.28772, -113.7717 , -126.60544, -190.81989], dtype=float32)
Autobatched with vmap
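It just works: `jax.vmap` maps the single-example `log_joint` over the leading (batch) axis of `batched_test_beta`, and the result matches the manually batched version above.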
[9]:
# `jax.vmap` maps the single-example `log_joint` over the leading axis of its
# input and stacks the results along a new leading axis; no `axis` arguments
# or transposes are needed.
vmap_batched_log_joint = jax.vmap(log_joint, in_axes=0, out_axes=0)
vmap_batched_log_joint(batched_test_beta)
[9]:
DeviceArray([-147.84032, -207.02205, -109.26076, -243.80833, -163.02908,
-143.84848, -160.28772, -113.7717 , -126.60544, -190.81989], dtype=float32)
Self-contained variational inference example
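A little code is copied from above, with two changes: the prior on `beta` is widened to a standard deviation of 10, and the log-joint functions are JIT-compiled.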
Set up the (batched) log-joint function
[10]:
@jax.jit
def log_joint(beta):
    """JIT-compiled log joint for one parameter vector, with an N(0, 10^2) prior.

    Note that no `axis` parameter is provided to `jnp.sum`; batching is added
    afterwards with `jax.vmap`.
    """
    log_prior = jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    neg_margins = -(2*y - 1) * jnp.dot(all_x, beta)
    log_likelihood = jnp.sum(-jnp.log(1 + jnp.exp(neg_margins)))
    return log_prior + log_likelihood


# Autobatch the single-example function, then compile the batched version as well.
batched_log_joint = jax.jit(jax.vmap(log_joint))
Define the ELBO and its gradient
[11]:
def elbo(beta_loc, beta_log_scale, epsilon):
    """Monte Carlo estimate of the ELBO for a mean-field Gaussian q(beta).

    Args:
      beta_loc: means of q, one per coefficient.
      beta_log_scale: log standard deviations of q, one per coefficient.
      epsilon: standard-normal noise; each row is turned into one sample from q
        (the SGD loop below passes shape (batch_size, num_features)).

    Returns:
      Scalar: the sample mean of `batched_log_joint` over the draws plus the
      analytic entropy of q.
    """
    # Reparameterization trick: beta = mu + sigma * eps, so gradients flow to mu and log(sigma).
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    expected_log_joint = jnp.mean(batched_log_joint(beta_sample), 0)
    # Entropy of N(mu, sigma^2) per coordinate is log(sigma) + 0.5 * (1 + log(2*pi)).
    # The additive constant does not affect gradients, only the reported ELBO value.
    entropy = jnp.sum(beta_log_scale + 0.5 * (1. + np.log(2*np.pi)))
    return expected_log_joint + entropy


# All three arguments are arrays, so none may be static: static arguments must be
# hashable and are baked into the compiled program (and beta_log_scale is
# differentiated below).
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
Optimize the ELBO using SGD
[12]:
def normal_sample(key, shape):
    """Draw standard normals of `shape` and also return the advanced PRNG key.

    JAX's PRNG is functional, so callers thread the returned key into the next
    call to get fresh draws (quasi-stateful usage).
    """
    next_key, draw_key = random.split(key)
    draws = random.normal(draw_key, shape)
    return next_key, draws


# `shape` is static: it fixes the output array's shape at trace time.
normal_sample = jax.jit(normal_sample, static_argnums=(1,))
# Start q at the origin with unit scale (log-scale 0) in every coordinate.
key = random.PRNGKey(10003)
beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)
step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)

# Stochastic gradient *ascent* on the ELBO; each step uses fresh noise.
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, grads = elbo_val_and_grad(beta_loc, beta_log_scale, epsilon)
    beta_loc_grad, beta_log_scale_grad = grads
    beta_loc = beta_loc + step_size * beta_loc_grad
    beta_log_scale = beta_log_scale + step_size * beta_log_scale_grad
    if not i % 10:
        print('{}\t{}'.format(i, elbo_val))
0 -180.8538818359375
10 -113.06045532226562
20 -102.73725891113281
30 -99.787353515625
40 -98.90898132324219
50 -98.29745483398438
60 -98.18630981445312
70 -97.5797348022461
80 -97.28600311279297
90 -97.469970703125
100 -97.4771728515625
110 -97.58067321777344
120 -97.49435424804688
130 -97.50271606445312
140 -96.86395263671875
150 -97.44197082519531
160 -97.06939697265625
170 -96.84028625488281
180 -97.21336364746094
190 -97.56502532958984
200 -97.26398468017578
210 -97.11979675292969
220 -97.39593505859375
230 -97.16830444335938
240 -97.118408203125
250 -97.24345397949219
260 -97.2978744506836
270 -96.69285583496094
280 -96.9643783569336
290 -97.30055236816406
300 -96.63594055175781
310 -97.03518676757812
320 -97.52909851074219
330 -97.28812408447266
340 -97.0732192993164
350 -97.15620422363281
360 -97.25882720947266
370 -97.19515228271484
380 -97.13092041015625
390 -97.11727905273438
400 -96.93873596191406
410 -97.26676940917969
420 -97.35324096679688
430 -97.21007537841797
440 -97.28434753417969
450 -97.16310119628906
460 -97.2612533569336
470 -97.21343994140625
480 -97.23997497558594
490 -97.14913177490234
500 -97.23528289794922
510 -96.9342041015625
520 -97.21209716796875
530 -96.82577514648438
540 -97.01286315917969
550 -96.94176483154297
560 -97.16522216796875
570 -97.29165649414062
580 -97.42939758300781
590 -97.24371337890625
600 -97.15219116210938
610 -97.49844360351562
620 -96.99070739746094
630 -96.88957977294922
640 -96.89970397949219
650 -97.13794708251953
660 -97.43707275390625
670 -96.99235534667969
680 -97.15623474121094
690 -97.18690490722656
700 -97.11160278320312
710 -97.78105163574219
720 -97.23226165771484
730 -97.16206359863281
740 -96.99581909179688
750 -96.66722869873047
760 -97.16796112060547
770 -97.51435089111328
780 -97.28901672363281
790 -96.91226196289062
800 -97.1709976196289
810 -97.29047393798828
820 -97.16242980957031
830 -97.1910629272461
840 -97.56382751464844
850 -97.00193786621094
860 -96.86555480957031
870 -96.76337432861328
880 -96.83661651611328
890 -97.12179565429688
900 -97.09554290771484
910 -97.0682373046875
920 -97.11947631835938
930 -96.8792953491211
940 -97.45625305175781
950 -96.69280242919922
960 -97.29376220703125
970 -97.3353042602539
980 -97.34962463378906
990 -97.09674835205078
Display the results
Coverage of the approximate ±2σ intervals isn’t quite as good as we might like, but it’s not bad, and nobody said variational inference was exact.
[13]:
# Compare approximate posterior means and +/- 2 sigma bounds against the true
# coefficients; points on the black diagonal would indicate perfect recovery.
figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
# Raw string: '\s' is not a valid Python escape sequence (warns on recent Pythons);
# the r-prefix yields the same text for matplotlib's mathtext.
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
# Trailing semicolon suppresses the Legend object's repr in the cell output.
legend(loc='best');
[13]:
<matplotlib.legend.Legend at 0x7f069c2b6d68>

[ ]: