Generative Modeling by Estimating Gradients of the Data Distribution in JAX


In this notebook we’ll implement Generative Modeling by Estimating Gradients of the Data Distribution [arxiv].

The paper builds on a technique called Score Matching to iteratively refine images into samples from the real data. This technique takes advantage of some recent theorems to learn the gradients of the data probability distribution. Implementing this in your typical DL framework would be a hassle. However, JAX ain’t your typical DL framework: it makes implementing complex math no harder than writing it down on paper.

Let’s begin with a simple task: learn to sample from a peculiar probability distribution. Like a swiss roll, for example:

[image: a swiss roll cake]

well… minus the chocolate. Sorry for that.

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

from sklearn.datasets import make_swiss_roll

def sample_batch(size, noise=1.0):
    x, _= make_swiss_roll(size, noise=noise)
    x = x[:, [0, 2]] / 10.0
    return np.array(x)

plt.scatter(*sample_batch(10**4).T, alpha=0.1)
[plot: scatter of 10,000 samples from the swiss roll distribution]

Compute score matching objective

The method we apply here was originally proposed by Hyvärinen (2005). The idea behind score matching is to learn scores: the gradients of \(\log p(x)\) w.r.t. \(x\). Once trained, such a model can “improve” a sample \(x\) by nudging it in the direction of higher log-probability. However, training such a model can get tricky. When predicting a continuous variable, ML folks usually minimize the squared error:

\[ L_{mse} = E_{x \sim p(x)} \left\lVert model(x) - \nabla_x \log p(x) \right\rVert_2^2 \]

One can’t minimize this explicitly because the real \(\nabla_x \log p(x)\) is usually unknown. However, under broad assumptions on \(p(x)\) and with a sufficiently powerful model, one can say that … math happens … (sketched below) and therefore the arg-minimum of \(L_{mse}\) can also be found by minimizing a more tractable objective:

\[ L_{matching} = E_{x \sim p(x)} \left[ tr( \space \mathbf{J}_x [\space model(x) \space]) + \frac12 \left\lVert model(x) \right\rVert_2^2 \right] \]
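
For the curious, here is a sketch of the “… math happens …” step; it is essentially an integration by parts, following Hyvärinen (2005), and is not needed to follow the code. Expanding the square in \(L_{mse}\) and dropping the term that does not depend on the model gives

\[ L_{mse} = E_{x \sim p(x)} \left\lVert model(x) \right\rVert_2^2 - 2 \, E_{x \sim p(x)} \left[ model(x)^T \nabla_x \log p(x) \right] + const \]

and the cross term can be rewritten without the unknown \(\nabla_x \log p(x)\) by using \(p(x) \nabla_x \log p(x) = \nabla_x p(x)\) and integrating by parts (assuming \(model(x) \, p(x) \rightarrow 0\) at infinity):

\[ E_{x \sim p(x)} \left[ model(x)^T \nabla_x \log p(x) \right] = \int model(x)^T \nabla_x p(x) \, dx = - \int tr(\mathbf{J}_x[model(x)]) \, p(x) \, dx = - E_{x \sim p(x)} \, tr(\mathbf{J}_x[model(x)]) \]

Putting the two together, \(L_{mse} = 2 \, L_{matching} + const\), so both objectives share the same arg-minimum.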

In \(L_{matching}\), \(tr( \space \mathbf{J}_x [\space model(x) \space])\) is the trace of the Jacobian of \(model(x)\) w.r.t. \(x\). Now all it takes is to minimize this objective with backpropagation… that is, if you can compute Jacobians. Thankfully, we have jax!

import jax
import jax.numpy as jnp
from jax.experimental import optimizers
from jax.experimental import stax
from functools import partial

# Set up network to predict scores
net_init, net_apply = stax.serial(
    stax.Dense(128), stax.Softplus,
    stax.Dense(128), stax.Softplus,
    stax.Dense(2),
)

# Create optimizer. Note that both the network and the optimizer return pure (stateless) functions
opt_init, opt_update, get_params = optimizers.adam(1e-3)
# v-- jax.jit compiles a function for efficient CPU and GPU execution

@jax.jit
def compute_loss(net_params, inputs):
    # a function that computes the jacobian of net_apply w.r.t. its inputs
    # (argnums=1 selects the `inputs` argument) by forward-mode differentiation
    jacobian = jax.jacfwd(net_apply, argnums=1)
    
    # we use jax.vmap to vectorize jacobian function along batch dimension
    batch_jacobian = jax.vmap(partial(jacobian, net_params))(inputs)  # [batch, dim, dim]
    
    trace_jacobian = jnp.trace(batch_jacobian, axis1=1, axis2=2)
    output_norm_sq = jnp.square(net_apply(net_params, inputs)).sum(axis=1)
    
    return jnp.mean(trace_jacobian + 1/2 * output_norm_sq)


@jax.jit
def train_step(step_i, opt_state, batch):
    net_params = get_params(opt_state)
    # compute the loss and its gradients w.r.t. net_params in one pass
    loss, grads = jax.value_and_grad(compute_loss, argnums=0)(net_params, batch)
    return loss, opt_update(step_i, grads, opt_state)

Note: we use jax.jacfwd since the input dimension is only 2; forward mode computes the Jacobian one input dimension at a time, so it is cheap for low-dimensional inputs.

Training loop

from IPython.display import clear_output

out_shape, net_params = net_init(jax.random.PRNGKey(seed=42), input_shape=(-1, 2))
opt_state = opt_init(net_params)

loss_history = []
for i in range(2000):
    x = sample_batch(size=128)
    loss, opt_state = train_step(i, opt_state, x)
    loss_history.append(loss.item())
    
    if i % 200 == 0:
        clear_output(True)
        plt.figure(figsize=[16, 8])
        plt.subplot(1, 2, 1)
        plt.title("mean loss = %.3f" % jnp.mean(jnp.array(loss_history[-32:])))
        plt.scatter(jnp.arange(len(loss_history)), loss_history)
        plt.grid()
        
        plt.subplot(1, 2, 2)
        net_params = get_params(opt_state)
        xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 2.0, 50), jnp.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)
        scores = net_apply(net_params, xx)
        scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)
        scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)

        plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')
        plt.xlim(-1.5, 2.0)
        plt.ylim(-1.5, 2.0)
        plt.show()
[plot: training loss curve (left) and learned score field (right)]

Plot gradient directions

Once the model is trained, we can use it to predict scores at each point. Since scores are gradient vectors, we’ll draw them with a quiver plot.

plt.figure(figsize=[16, 16])

net_params = get_params(opt_state)
xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 1.5, 50), jnp.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)
scores = net_apply(net_params, xx)
scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)
scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)

plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')
plt.scatter(*sample_batch(10_000).T, alpha=0.25)
[plot: learned score field (green arrows) over the swiss roll samples]

A hot new paper by Song et al. (2019) uses this method to generate images by iterative refinement… Apparently it took DL researchers 14 years to understand the proof :)

Seriously though, this paper takes advantage of two new ideas: sampling with Langevin Dynamics and scaling to high dimensions with Sliced Score Matching. We’ll cover them one at a time.

Sampling with Langevin Dynamics

Once we have \(\nabla_x \log p(x)\), we can use it to generate data. One simple thing we could do is gradient ascent w.r.t. the image to find a local maximum of \(p(x)\):

\[\hat x_{t + 1} := \hat x_t + \epsilon \nabla_{\hat x_t} \log p(\hat x_t)\]

In order to sample \(x \sim p(x)\), one can run a slightly more sophisticated procedure:

\[\hat x_{t+1} := \hat x_t + \frac \epsilon 2 \nabla_{\hat x_t} \log p(\hat x_t) + \sqrt \epsilon \, z_t, \quad z_t \sim N(0, I)\]

Performing this update multiple times in an MCMC fashion is a special case of Langevin Dynamics. As \(\epsilon \rightarrow 0, t \rightarrow \infty\), \(\hat x_t\) converges to a sample from \(p(x)\). You can find a more detailed explanation and a formal proof in Welling et al. (2011), and further exploration of SGLD in Teh et al. (2014) and Vollmer et al. (2015).

In practice, we can initialize \(x_0\) from some initial guess (e.g. a uniform distribution over the data space) and set \(\epsilon\) to some positive value. As the sampling progresses, we anneal \(\epsilon\) until we are satisfied with the samples. Okay, now let’s go implement that :)

def sample_langevin(x_initial, *, net_params, key, eps=1e-2, eps_decay=0.9, num_steps=15, temperature=1.0):
    """ sample x ~ p(x) by applying approximate Langvenin Dynamics, return a sequence of x_t """
    x_t, x_sequence = x_initial, [x_initial]
    
    for t in range(num_steps):
        key, subkey = jax.random.split(key)
        z_t = jax.random.normal(subkey, shape=x_t.shape)
        x_t = x_t + eps / 2 * net_apply(net_params, x_t) + jnp.sqrt(eps) * temperature * z_t
        x_sequence.append(x_t)
        eps *= eps_decay
        
    return jnp.stack(x_sequence)

plt.figure(figsize=[16, 16])

key = jax.random.PRNGKey(42)
net_params = get_params(opt_state)

for x_initial in jnp.array([[-1.5, -1.5], [0, 0], [1.5, 0]]):
    key, subkey = jax.random.split(key)
    # sample x sequence
    xx = sample_langevin(x_initial, key=subkey, net_params=net_params)
    plt.scatter(xx.T[0], xx.T[1], color="blue")

    # draw arrows for each mcmc step
    deltas = (xx[1:] - xx[:-1])
    deltas = deltas - deltas / jnp.linalg.norm(deltas, keepdims=True, axis=-1) * 0.04
    for i, arrow in enumerate(deltas):
        plt.arrow(xx[i][0], xx[i][1], arrow[0], arrow[1], width=1e-4, head_width=2e-2, color="orange")
    
# plot data points and gradients
plt.plot()
xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 1.5, 50), jnp.linspace(-1.5, 1.5, 50)), axis=-1).reshape(-1, 2)
scores = net_apply(net_params, xx)
scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)
scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)
plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')
plt.scatter(*sample_batch(10_000).T, alpha=0.025)
[plot: Langevin trajectories (orange arrows) over the score field and data samples]

Sliced Score Matching

Now, the problem with our previous loss function is that computing \(tr(\mathbf{J}_x [\space model(x)])\) requires the full Jacobian, i.e. one forward-mode pass per input dimension, so it does not scale to high-dimensional problems. The solution is to use Jacobian-vector products, which can be computed cheaply with forward-mode auto-differentiation. This method is called Sliced Score Matching and was proposed by Song et al. (2019).

Our new objective looks like this:

\[E_{\mathbf{v} \sim \mathcal{N}(0, I)} E_{x \sim p(x)} \left[ \mathbf{v}^T \mathbf{J}_x[model(x)] \mathbf{v} + \frac{1}{2} (\mathbf{v}^T model(x))^2 \right]\]
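
To see why this matches the earlier \(L_{matching}\), take the expectation over \(\mathbf{v}\): since \(E[\mathbf{v}\mathbf{v}^T] = I\) for \(\mathbf{v} \sim \mathcal{N}(0, I)\),

\[ E_{\mathbf{v}} \left[ \mathbf{v}^T \mathbf{J} \mathbf{v} \right] = tr(\mathbf{J} \, E[\mathbf{v}\mathbf{v}^T]) = tr(\mathbf{J}), \qquad E_{\mathbf{v}} \left[ (\mathbf{v}^T model(x))^2 \right] = \left\lVert model(x) \right\rVert_2^2 \]

(the first identity is Hutchinson’s trace estimator). The payoff is that \(\mathbf{v}^T \mathbf{J}_x[model(x)] \, \mathbf{v}\) only requires the Jacobian-vector product \(\mathbf{J}_x[model(x)] \, \mathbf{v}\), never the full Jacobian.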

Jacobian Vector products, by the way, can be easily computed using jax.jvp.

@jax.jit
def compute_ssm_loss(net_params, inputs, key):
    # bind the parameters so that `apply` is a function of inputs only
    apply = partial(net_apply, net_params)
    batch_dot = partial(jnp.einsum, 'bu,bu->b')
    
    # generate random projection vectors from N(0, I)
    v = jax.random.normal(key, shape=inputs.shape)

    # predict scores and compute the jacobian-vector product J_x[model(x)] @ v in a single forward pass
    score, jac_v = jax.jvp(apply, (inputs,), (v,))
    
    return jnp.mean(batch_dot(v, jac_v) + 1/2 * batch_dot(v, score) ** 2)

@jax.jit
def train_step(step_i, opt_state, batch, key):
    # the new loss depends on a random key, hence this new version of train_step
    net_params = get_params(opt_state)
    # compute the loss and its gradients w.r.t. net_params in one pass
    loss, grads = jax.value_and_grad(compute_ssm_loss, argnums=0)(net_params, batch, key)
    return loss, opt_update(step_i, grads, opt_state)

Note: unlike before, we never materialize the Jacobian here: jax.jvp computes the Jacobian-vector product in a single forward-mode pass, regardless of the input dimension. You can read more about autodiff modes in the jax documentation and on wikipedia.
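
To make the difference concrete, here is a small standalone comparison on a toy function (not part of the training code; f below is just an arbitrary stand-in for a 64-dimensional score network):

import jax
import jax.numpy as jnp

# an arbitrary R^64 -> R^64 toy function standing in for a score network
f = lambda x: jnp.tanh(x) * x.sum()

x = jnp.ones(64)
v = jnp.ones(64)

full_jacobian = jax.jacfwd(f)(x)         # shape [64, 64]: ~64 forward-mode passes' worth of work
_, jac_times_v = jax.jvp(f, (x,), (v,))  # shape [64]: a single forward-mode pass

# sanity check: both paths agree on J @ v (up to float32 noise)
print(jnp.allclose(full_jacobian @ v, jac_times_v, atol=1e-3))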

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
out_shape, net_params = net_init(subkey, input_shape=(-1, 2))
opt_state = opt_init(net_params)

loss_history = []
for i in range(2_000):
    x = sample_batch(size=128)
    
    key, subkey = jax.random.split(key)
    loss, opt_state = train_step(i, opt_state, x, subkey)
    loss_history.append(loss.item())
    
    if i % 200 == 0:
        clear_output(True)
        plt.figure(figsize=[16, 8])
        plt.subplot(1, 2, 1)
        plt.title("mean loss = %.3f" % jnp.mean(jnp.array(loss_history[-32:])))
        plt.scatter(jnp.arange(len(loss_history)), loss_history)
        plt.grid()
        
        plt.subplot(1, 2, 2)
        net_params = get_params(opt_state)
        xx = jnp.stack(jnp.meshgrid(jnp.linspace(-1.5, 2.0, 50), jnp.linspace(-1.5, 2.0, 50)), axis=-1).reshape(-1, 2)
        scores = net_apply(net_params, xx)
        scores_norm = jnp.linalg.norm(scores, axis=-1, ord=2, keepdims=True)
        scores_log1p = scores / (scores_norm + 1e-9) * jnp.log1p(scores_norm)

        plt.quiver(*xx.T, *scores_log1p.T, width=0.002, color='green')
        plt.xlim(-1.5, 2.0)
        plt.ylim(-1.5, 2.0)
        plt.show()
[plot: training loss curve (left) and learned score field (right)]

Easy? Let’s go deeper!

MNIST 8x8: computing the full Jacobian would require 64 passes through the network

from sklearn.datasets import load_digits
import numpy as np

X, _ = load_digits(return_X_y=True)

for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.imshow(X[i].reshape(8, 8), cmap='gray')
    

def sample_batch(size, noise=0.1):
    ix = np.random.randint(0, len(X), size=size)
    return jnp.array(X[ix] / 16 + noise * np.random.randn(size, 64))
[figure: the first five 8x8 digit images]

# Set up network to predict scores
net_init, net_apply = stax.serial(
    stax.Dense(128), stax.Softplus,
    stax.Dense(128), stax.Softplus,
    stax.Dense(64),
)

# Create optimizer. Note that both the network and the optimizer return pure (stateless) functions
opt_init, opt_update, get_params = optimizers.adam(1e-3)

key = jax.random.PRNGKey(seed=42)
key, subkey = jax.random.split(key)
out_shape, net_params = net_init(subkey, input_shape=(-1, 64))
opt_state = opt_init(net_params)

loss_history = []
for i in range(5_000):
    x = sample_batch(size=128)
    key, subkey = jax.random.split(key)
    loss, opt_state = train_step(i, opt_state, x, subkey)
    loss_history.append(loss.item())
    
    if i % 500 == 0:
        clear_output(True)
        plt.title("mean loss = %.3f" % jnp.mean(jnp.array(loss_history[-32:])))
        plt.scatter(jnp.arange(len(loss_history)), loss_history)
        plt.show()
[plot: training loss curve]
key, subkey = jax.random.split(key)
x = 0.1 * jax.random.uniform(subkey, shape=(64,))

xx = sample_langevin(x, net_params=get_params(opt_state), key=key,
                     eps=0.05, eps_decay=0.98, num_steps=50,
                     temperature=0.02)  # set low temperature to compensate for noise in training data

for t, x_t in enumerate(xx):
    clear_output(True)
    plt.imshow(x_t.reshape(8, 8), cmap='gray')
    plt.title('step %i' % t); plt.colorbar(); plt.show()
[image: the sampled digit at the final Langevin step]

This is just the beginning

In their paper, Song et al. (2019) propose a more sophisticated sampling procedure that can efficiently sample larger images. They also use a technique called Denoising Score Matching, which can be safely ported even to earthling frameworks like TensorFlow and PyTorch. Go take a look!
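
For reference, here is a minimal sketch of what a denoising score matching loss could look like in this notebook’s setup, assuming a single fixed noise level sigma; the actual paper trains a noise-conditional network over a whole ladder of noise levels, which is not reproduced here. It reuses net_apply, jax and jnp from above.

@jax.jit
def compute_dsm_loss(net_params, inputs, key, sigma=0.1):
    # perturb the data: x_tilde = x + sigma * eps, with eps ~ N(0, I)
    eps = jax.random.normal(key, shape=inputs.shape)
    x_tilde = inputs + sigma * eps

    # denoising score matching regresses onto the score of the noise kernel:
    # grad_{x_tilde} log q(x_tilde | x) = -(x_tilde - x) / sigma^2 = -eps / sigma
    target = -eps / sigma

    # squared error between predicted and target scores, scaled by sigma^2
    # (the weighting used when combining multiple noise levels)
    diff = net_apply(net_params, x_tilde) - target
    return 0.5 * sigma ** 2 * jnp.mean(jnp.sum(diff ** 2, axis=-1))

Training with it would mirror the loop above: split a fresh key each step and swap compute_ssm_loss for compute_dsm_loss inside train_step.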

[image: CelebA samples from Song et al. (2019)]

Notebook author: Denis Mazur, edited by Just Heuristic