# MAML Tutorial with JAX

Eric Jang

Blog post: https://blog.evjang.com/2019/02/maml-jax.html

21 Feb 2019

Pedagogical tutorial for implementing Model-Agnostic Meta-Learning with JAX's awesome `grad`, `vmap`, and `jit` operators.

## Overview

In this notebook we'll go through:

• how to take gradients, gradients of gradients.

• how to fit a sinusoid function with a neural network (and do auto-batching with vmap)

• how to implement MAML and check its numerics

• how to implement MAML for the sinusoid task (single-task objective, batching task instances).

• extending MAML to handle batching at the task-level

### import jax.numpy (almost-drop-in for numpy) and gradient operators.
import jax.numpy as jnp
from jax import grad


## Gradients of Gradients

JAX makes it easy to compute gradients of Python functions. Here, we differentiate $$e^x$$ and $$x^2$$ three times.

# f(x) = e^x is its own derivative, so every order of derivative at x=1 is e.
f = lambda x: jnp.exp(x)
# g(x) = x^2 has g'(x) = 2x, g''(x) = 2 and g'''(x) = 0.
g = lambda x: jnp.square(x)
print(grad(f)(1.))  # e^1
print(grad(grad(f))(1.))  # e^1
print(grad(grad(grad(f)))(1.))  # e^1

print(grad(g)(2.))  # g'(2) = 2x = 4
print(grad(grad(g))(2.))  # g''(2) = 2 (constant)
print(grad(grad(grad(g)))(2.))  # g'''(2) = 0

2.7182817
2.7182817
2.7182817
4.0
2.0
0.0


## Sinusoid Regression and vmap

To get you familiar with JAX syntax first, we'll optimize neural network parameters $$\theta$$ on fixed inputs so that $$f_\theta(x)$$ fits $$\sin(x)$$ under a mean-squared-error loss.

from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization

# stax.serial composes layers into a pair of pure functions:
#   net_init(rng, input_shape) -> (output_shape, params)
#   net_apply(params, inputs)  -> outputs
# Architecture: 1 -> 40 (ReLU) -> 40 (ReLU) -> 1.
net_init, net_apply = stax.serial(Dense(40), Relu, Dense(40), Relu, Dense(1))

# -1 leaves the batch dimension unspecified; each example has one feature.
in_shape = (-1, 1)
rng = random.PRNGKey(0)
out_shape, net_params = net_init(rng, in_shape)

def loss(params, inputs, targets):
    """Return the mean squared error between ``net_apply(params, inputs)`` and ``targets``.

    The mean runs over every element, giving one average loss for the batch.
    """
    errors = net_apply(params, inputs) - targets
    return jnp.mean(jnp.square(errors))

# Evaluate the untrained network on K=100 evenly spaced points in [-5, 5].
xrange_inputs = jnp.linspace(-5, 5, num=100)[:, None]  # shape (K, 1)
targets = jnp.sin(xrange_inputs)
# vmap maps over the leading (K) axis, so each call sees a single (1,) input.
predictions = vmap(lambda x: net_apply(net_params, x))(xrange_inputs)
losses = vmap(lambda x, y: loss(net_params, x, y))(xrange_inputs, targets)  # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

<matplotlib.legend.Legend at 0x7f6e724302b0>

import numpy as np
from jax.experimental import optimizers
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays

# optimizers.adam returns three pure functions that thread an explicit state:
#   opt_init(params) -> state
#   opt_update(step, grads, state) -> state
#   get_params(state) -> params
opt_init, opt_update, get_params = optimizers.adam(1e-2)
opt_state = opt_init(net_params)

# Define a compiled update step
@jit
def step(i, opt_state, x1, y1):
p = get_params(opt_state)
g = grad(loss)(p, x1, y1)
return opt_update(i, g, opt_state)

# Fit the network to the 100 fixed sin(x) points with 100 full-batch Adam steps.
for i in range(100):
    opt_state = step(i, opt_state, xrange_inputs, targets)
net_params = get_params(opt_state)

# Plot the trained network's predictions and per-input losses on the same grid.
targets = jnp.sin(xrange_inputs)
predictions = vmap(lambda x: net_apply(net_params, x))(xrange_inputs)
losses = vmap(lambda x, y: loss(net_params, x, y))(xrange_inputs, targets)  # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

<matplotlib.legend.Legend at 0x7f6e72d99080>


## MAML: Optimizing for Generalization

Suppose the task loss function $$\mathcal{L}$$ is defined with respect to model parameters $$\theta$$, input features $$X$$, and labels $$Y$$. MAML optimizes the following:

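$$\mathcal{L}(\theta - \alpha \nabla_\theta \mathcal{L}(\theta, x_1, y_1), x_2, y_2)$$

where $$\alpha$$ is the inner-loop step size (the toy check below uses $$\alpha = 1$$).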

$$x_1, y_1$$ and $$x_2, y_2$$ are drawn from the same distribution over $$X, Y$$. Therefore, the MAML objective can be thought of as a differentiable cross-validation error (w.r.t. $$x_2, y_2$$) for a model that learns (via a single gradient descent step) from $$x_1, y_1$$. Minimizing cross-validation error provides an inductive bias toward generalization.

The following toy example checks MAML numerics via parameter $$x$$ and input $$y$$.

# Numerical check of the MAML objective on a toy problem:
#   g(x, y) = x^2 + y, with x playing the role of the parameter and y the input.
# One inner gradient step with unit step size maps x to x - 2x = -x, so
#   maml_objective(x, y) = g(-x, y) = x^2 + y   and   d/dx maml_objective = 2x.
g = lambda x, y: jnp.square(x) + y
x0 = 2.
y0 = 1.
print('grad(g)(x0) = {}'.format(grad(g)(x0, y0)))  # 2x = 4
print('x0 - grad(g)(x0) = {}'.format(x0 - grad(g)(x0, y0)))  # x - 2x = -2


def maml_objective(x, y):
    """Evaluate g after one unit-step gradient-descent update of x."""
    return g(x - grad(g)(x, y), y)


print('maml_objective(x,y)={}'.format(maml_objective(x0, y0)))  # x**2 + 1 = 5
# The outer gradient flows through the inner update: d/dx (x - 2x)^2 = 2x = 4,
# so one outer step gives x - 2x = -2. (The label names the gradient, which is
# what is actually printed.)
print('x0 - grad(maml_objective)(x,y) = {}'.format(x0 - grad(maml_objective)(x0, y0)))  # x - 2x = -2

grad(g)(x0) = 4.0
x0 - grad(g)(x0) = -2.0
maml_objective(x,y)=5.0
x0 - maml_objective(x,y) = -2.0


## Sinusoid Task + MAML

Now let's re-implement the sinusoid regression task from Chelsea Finn's MAML paper.

alpha = .1  # step size of the inner-loop (task adaptation) SGD update


def inner_update(p, x1, y1):
    """Return the parameters after one SGD step on ``loss(p, x1, y1)``.

    ``p`` may be any pytree of arrays (e.g. the params returned by
    ``net_init``); ``p - alpha * grad`` is applied leaf by leaf.
    """
    param_grads = grad(loss)(p, x1, y1)
    return tree_multimap(lambda leaf, leaf_grad: leaf - alpha * leaf_grad, p, param_grads)

def maml_loss(p, x1, y1, x2, y2):
    """Single-task MAML objective.

    Adapt ``p`` with one inner SGD step on the (x1, y1) split, then score the
    adapted parameters on the held-out (x2, y2) split. Differentiating this
    with respect to ``p`` back-propagates through the inner update.
    """
    return loss(inner_update(p, x1, y1), x2, y2)

# Smoke test: adapt on the K=100 sin(x) points, then evaluate on the single point (0, 0).
x1, y1 = xrange_inputs, targets
x2, y2 = jnp.array([0.]), jnp.array([0.])
maml_loss(net_params, x1, y1, x2, y2)

DeviceArray(2.5027603e-05, dtype=float32)


Let's try minimizing the MAML loss (without batching across multiple tasks, which we will do in the next section).

# Fresh network and optimizer for meta-training. A learning rate of 1e-3 seemed
# to work better here than either 1e-2 or 1e-4. Reusing `rng` reproduces the
# same initial parameters as the first network.
opt_init, opt_update, get_params = optimizers.adam(1e-3)
out_shape, net_params = net_init(rng, in_shape)
opt_state = opt_init(net_params)

@jit
def step(i, opt_state, x1, y1, x2, y2):
    """One jit-compiled Adam step on the single-task MAML loss.

    Returns ``(new_opt_state, loss)``; the loss is evaluated at the
    parameters *before* the update.
    """
    params = get_params(opt_state)
    meta_grads = grad(maml_loss)(params, x1, y1, x2, y2)
    meta_loss = maml_loss(params, x1, y1, x2, y2)
    return opt_update(i, meta_grads, opt_state), meta_loss
K = 20  # number of examples in the inner (adaptation) split

np_maml_loss = []  # per-step MAML loss, kept for plotting later

# Meta-train with Adam, sampling one fresh sinusoid task per step.
for i in range(20000):
    # Task: y = A * sin(x + phase), A ~ U[0.1, 0.5), phase ~ U[0, pi).
    A = np.random.uniform(low=0.1, high=.5)
    phase = np.random.uniform(low=0., high=jnp.pi)
    # Inner split: K points used for the adaptation step.
    x1 = np.random.uniform(low=-5., high=5., size=(K, 1))
    y1 = A * np.sin(x1 + phase)
    # Outer split: a single held-out point, like cross-validating on one example.
    x2 = np.random.uniform(low=-5., high=5.)
    y2 = A * np.sin(x2 + phase)
    opt_state, l = step(i, opt_state, x1, y1, x2, y2)
    np_maml_loss.append(l)
    if not i % 1000:
        print(i)
net_params = get_params(opt_state)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000

# Evaluate the meta-learned initialization on the K=100 grid, then adapt it to
# a new task with a few inner SGD steps on K freshly sampled points.
# NOTE: the test task uses amplitude 1.0 and phase 0, which is outside the
# amplitude range [0.1, 0.5) sampled during meta-training.
targets = jnp.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
plt.plot(xrange_inputs, predictions, label='pre-update predictions')
plt.plot(xrange_inputs, targets, label='target')

x1 = np.random.uniform(low=-5., high=5., size=(K, 1))
y1 = 1. * np.sin(x1 + 0.)

# Adapt under a separate name so the meta-learned net_params are preserved
# and re-running this cell does not compound adaptation steps.
adapted_params = net_params
for i in range(1, 5):
    adapted_params = inner_update(adapted_params, x1, y1)
    predictions = vmap(partial(net_apply, adapted_params))(xrange_inputs)
    # i counts gradient steps on the same K points, not examples ("shots").
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i))
plt.legend()

<matplotlib.legend.Legend at 0x7f6e5ff89a58>


## Batching Meta-Gradient Across Tasks

It kind of does the job, but not that well. Let's reduce the variance of the outer-loop gradients by averaging across a batch of tasks (not just one task at a time).

vmap is awesome: it enables batching at two levels: inner-level “intra-task” batching, and outer-level batching across tasks.

From a software engineering perspective, this is nice because the “task-batched” MAML implementation simply re-uses code from the non-task-batched MAML algorithm, without losing any vectorization benefits.

def sample_tasks(outer_batch_size, inner_batch_size):
    """Sample a batch of sinusoid regression tasks, each with two data splits.

    Each task is y = A * sin(x + phase) with A ~ U[0.1, 0.5) and
    phase ~ U[0, pi); inputs are drawn from U[-5, 5).

    Args:
      outer_batch_size: number of tasks.
      inner_batch_size: number of (x, y) points per task in each split.

    Returns:
      (x1, y1, x2, y2): the train and validation splits, each stacked into
      an array of shape (outer_batch_size, inner_batch_size, 1).
    """
    # Amplitude and phase are drawn alternately, task by task.
    task_params = [
        (np.random.uniform(low=0.1, high=.5), np.random.uniform(low=0., high=jnp.pi))
        for _ in range(outer_batch_size)
    ]

    def draw_split():
        # One input draw per task, in task order; targets consume no randomness.
        inputs = [np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
                  for _ in task_params]
        outputs = [amp * np.sin(x + ph) for (amp, ph), x in zip(task_params, inputs)]
        return jnp.stack(inputs), jnp.stack(outputs)

    train_x, train_y = draw_split()
    val_x, val_y = draw_split()
    return train_x, train_y, val_x, val_y

# Visualize two sampled tasks with 50 training and 50 validation points each.
outer_batch_size = 2
x1, y1, x2, y2 = sample_tasks(outer_batch_size, 50)
# All training splits are drawn first, then all validation splits, which fixes
# the color assignment and legend order.
for i, (task_x, task_y) in enumerate(zip(x1, y1)):
    plt.scatter(task_x, task_y, label='task{}-train'.format(i))
for i, (task_x, task_y) in enumerate(zip(x2, y2)):
    plt.scatter(task_x, task_y, label='task{}-val'.format(i))
plt.legend()

<matplotlib.legend.Legend at 0x7f6e5ff63748>

np.shape(x2)  # (outer_batch_size, inner_batch_size, 1)

(2, 50, 1)

# Re-initialize the network (same rng, hence same starting params) and a fresh
# Adam optimizer for task-batched meta-training.
opt_init, opt_update, get_params = optimizers.adam(1e-3)
out_shape, net_params = net_init(rng, in_shape)
opt_state = opt_init(net_params)

def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
    """Average the single-task MAML loss over a batch of tasks.

    Every task shares the same parameters ``p`` (broadcast with
    ``in_axes=None``). ``maml_loss`` is mapped over the leading (task) axis of
    the four data arrays, and the per-task losses are reduced to one scalar
    by their mean.
    """
    per_task = vmap(maml_loss, in_axes=(None, 0, 0, 0, 0))(p, x1_b, y1_b, x2_b, y2_b)
    return per_task.mean()

@jit
def step(i, opt_state, x1, y1, x2, y2):
    """One jit-compiled Adam step on the task-batched MAML loss.

    The data arrays carry a leading task axis. Returns
    ``(new_opt_state, loss)``, with the loss evaluated before the update.
    """
    params = get_params(opt_state)
    batch_args = (x1, y1, x2, y2)
    meta_grads = grad(batch_maml_loss)(params, *batch_args)
    return opt_update(i, meta_grads, opt_state), batch_maml_loss(params, *batch_args)

np_batched_maml_loss = []  # per-step task-batched MAML loss, kept for plotting
K = 20  # points per task in each split

# Meta-train with Adam on a batch of 4 freshly sampled tasks per step.
for i in range(20000):
    x1_b, y1_b, x2_b, y2_b = sample_tasks(4, K)
    opt_state, l = step(i, opt_state, x1_b, y1_b, x2_b, y2_b)
    np_batched_maml_loss.append(l)
    if not i % 1000:
        print(i)
net_params = get_params(opt_state)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000

# Evaluate the task-batched meta-learned initialization on the K=100 grid,
# then adapt it to a new task with inner SGD steps on 10 sampled points.
# NOTE: the test task (amplitude 1.0, phase 0) lies outside the amplitude
# range [0.1, 0.5) used by sample_tasks during meta-training.
targets = jnp.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
plt.plot(xrange_inputs, predictions, label='pre-update predictions')
plt.plot(xrange_inputs, targets, label='target')

x1 = np.random.uniform(low=-5., high=5., size=(10, 1))
y1 = 1. * np.sin(x1 + 0.)

# Adapt under a separate name so the meta-learned net_params are preserved
# and re-running this cell does not compound adaptation steps.
adapted_params = net_params
for i in range(1, 3):
    adapted_params = inner_update(adapted_params, x1, y1)
    predictions = vmap(partial(net_apply, adapted_params))(xrange_inputs)
    # i counts gradient steps on the same 10 points, not examples ("shots").
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i))
plt.legend()

<matplotlib.legend.Legend at 0x7f6e5ff63e10>

# Compare meta-training curves for task batch size = 1 vs. task batch size = 4.
# Each curve is smoothed with a 20-step moving average. mode='valid' keeps only
# fully overlapping windows, avoiding the artificial ramps that the default
# zero-padded 'full' mode creates at both ends of the curve.
plt.plot(np.convolve(np_maml_loss, [.05] * 20, mode='valid'), label='task_batch=1')
plt.plot(np.convolve(np_batched_maml_loss, [.05] * 20, mode='valid'), label='task_batch=4')
plt.ylim(0., 1e-1)
plt.legend()

<matplotlib.legend.Legend at 0x7f6e5f5d9710>