# Convolutions in JAX

JAX provides a number of interfaces to compute convolutions across data, including:

• jax.numpy.convolve() (also jax.numpy.correlate())

• jax.scipy.signal.convolve() (also convolve2d)

• jax.lax.conv_general_dilated()

For basic convolution operations, the jax.numpy and jax.scipy operations are usually sufficient. If you want to do more general batched multi-dimensional convolution, the jax.lax function is where you should start.
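As a rough sketch of how these levels relate, here is the same 1D smoothing written against each interface. The window is symmetric, so the kernel flip that distinguishes true convolution (the numpy/scipy functions) from the cross-correlation computed by lax.conv_general_dilated makes no difference here:

```python
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax

signal = jnp.arange(10.0)
window = jnp.ones(3) / 3  # symmetric, so convolution == cross-correlation

out_np = jnp.convolve(signal, window, mode='same')         # jax.numpy level
out_sp = jsp.signal.convolve(signal, window, mode='same')  # jax.scipy level

# jax.lax level: expects explicit batch and channel dimensions.
out_lax = lax.conv_general_dilated(signal[None, None, :],  # (batch, channel, width)
                                   window[None, None, :],  # (out chan, in chan, width)
                                   window_strides=(1,),
                                   padding='SAME')

print(out_np.shape, out_sp.shape, out_lax.shape)  # (10,) (10,) (1, 1, 10)
```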

## Basic One-dimensional Convolution

Basic one-dimensional convolution is implemented by jax.numpy.convolve(), which provides a JAX interface for numpy.convolve(). Here is a simple example of 1D smoothing implemented via a convolution:

```python
import matplotlib.pyplot as plt

from jax import random
import jax.numpy as jnp
import numpy as np

key = random.PRNGKey(1701)

x = jnp.linspace(0, 10, 500)
y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))

window = jnp.ones(10) / 10
y_smooth = jnp.convolve(y, window, mode='same')

plt.plot(x, y, 'lightgray')
plt.plot(x, y_smooth, 'black');
```

The mode parameter controls how boundary conditions are treated; here we use mode='same' to ensure that the output is the same size as the input.
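The other modes differ only in how much of the boundary region is kept, which shows up directly in the output length; a quick check using the arrays above:

```python
for mode in ['full', 'same', 'valid']:
    print(f"mode={mode!r}:", jnp.convolve(y, window, mode=mode).shape)

# mode='full':  (509,)  len(y) + len(window) - 1
# mode='same':  (500,)  len(y)
# mode='valid': (491,)  len(y) - len(window) + 1
```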

For more information, see the jax.numpy.convolve() documentation, or the documentation associated with the original numpy.convolve() function.

## Basic N-dimensional Convolution

For N-dimensional convolution, jax.scipy.signal.convolve() provides a similar interface to that of jax.numpy.convolve(), generalized to N dimensions.

For example, here is a simple approach to de-noising an image based on convolution with a Gaussian filter:

```python
from scipy import misc  # note: newer SciPy versions provide this image via scipy.datasets.face
import jax.scipy as jsp

fig, ax = plt.subplots(1, 3, figsize=(12, 5))

# Load a sample image; compute mean() to convert from RGB to grayscale.
image = jnp.array(misc.face().mean(-1))
ax[0].imshow(image, cmap='binary_r')
ax[0].set_title('original')

# Create a noisy version by adding random Gaussian noise
key = random.PRNGKey(1701)
noisy_image = image + 50 * random.normal(key, image.shape)
ax[1].imshow(noisy_image, cmap='binary_r')
ax[1].set_title('noisy')

# Smooth the noisy image with a 2D Gaussian smoothing kernel.
x = jnp.linspace(-3, 3, 7)
window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None])
smooth_image = jsp.signal.convolve(noisy_image, window, mode='same')
ax[2].imshow(smooth_image, cmap='binary_r')
ax[2].set_title('smoothed');
```

Like in the one-dimensional case, we use mode='same' to specify how we would like edges to be handled. For more information on available options in N-dimensional convolutions, see the jax.scipy.signal.convolve() documentation.
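As an aside, the 2D window above is an outer product of 1D Gaussians, so it is separable: the same smoothing can be computed as two 1D passes, one per axis, which is typically cheaper for large kernels. A minimal sketch reusing the arrays defined above:

```python
# Separable filtering: convolve with a column vector, then a row vector.
window_1d = jsp.stats.norm.pdf(jnp.linspace(-3, 3, 7))
tmp = jsp.signal.convolve(noisy_image, window_1d[:, None], mode='same')
smooth_separable = jsp.signal.convolve(tmp, window_1d[None, :], mode='same')

# Should match the 2D result up to float32 roundoff:
print(jnp.max(jnp.abs(smooth_image - smooth_separable)))
```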

## General Convolutions

For the more general types of batched convolutions often useful in the context of building deep neural networks, JAX and XLA offer the very general N-dimensional conv_general_dilated function, but it’s not very obvious how to use it. We’ll give some examples of the common use-cases.

A survey of the family of convolutional operators, a guide to convolutional arithmetic, is highly recommended reading!

Let’s define a simple diagonal edge kernel:

```python
# 2D kernel - HWIO layout
kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32)
kernel += jnp.array([[1, 1, 0],
                     [1, 0,-1],
                     [0,-1,-1]])[:, :, jnp.newaxis, jnp.newaxis]

print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
```

Edge Conv kernel:

And we'll make a simple synthetic image:

```python
# NHWC layout
img = jnp.zeros((1, 200, 198, 3), dtype=jnp.float32)
for k in range(3):
    x = 30 + 60*k
    y = 20 + 60*k
    img = img.at[0, x:x+10, y:y+10, k].set(1.0)

print("Original Image:")
plt.imshow(img[0]);
```

Original Image:

### lax.conv and lax.conv_with_general_padding

These are the simple convenience functions for convolutions.

⚠️ The convenience lax.conv, lax.conv_with_general_padding helper functions assume NCHW images and OIHW kernels.

```python
from jax import lax
out = lax.conv(jnp.transpose(img, [0, 3, 1, 2]),    # lhs = NCHW image tensor
               jnp.transpose(kernel, [3, 2, 0, 1]), # rhs = OIHW conv kernel tensor
               (1, 1),  # window strides
               'SAME')  # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
```

out shape:  (1, 3, 200, 198)
First output channel:

```python
out = lax.conv_with_general_padding(
    jnp.transpose(img, [0, 3, 1, 2]),    # lhs = NCHW image tensor
    jnp.transpose(kernel, [2, 3, 0, 1]), # rhs = IOHW conv kernel tensor
    (1, 1),           # window strides
    ((2, 2), (2, 2)), # general padding: 2 on each side of H and W
    (1, 1),           # lhs/image dilation
    (1, 1))           # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,0,:,:]);
```

out shape:  (1, 3, 202, 200)
First output channel:

### Dimension Numbers define dimensional layout for conv_general_dilated

The important argument is the 3-tuple of axis layout arguments: (Input Layout, Kernel Layout, Output Layout)

• N - batch dimension

• H - spatial height

• W - spatial width

• C - channel dimension

• I - kernel input channel dimension

• O - kernel output channel dimension

⚠️ To demonstrate the flexibility of dimension numbers, we choose an NHWC image and HWIO kernel convention for lax.conv_general_dilated below.

```python
dn = lax.conv_dimension_numbers(img.shape,     # only ndim matters, not shape
                                kernel.shape,  # only ndim matters, not shape
                                ('NHWC', 'HWIO', 'NHWC'))  # the important bit
print(dn)
```

ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
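For comparison, passing the NCHW/OIHW convention that the lax.conv wrappers assume yields identity permutations, since each spec tuple simply records where the (batch, feature, spatial...) axes sit in the corresponding operand:

```python
dn_nchw = lax.conv_dimension_numbers((1, 3, 200, 198),  # NCHW image shape
                                     (3, 3, 3, 3),      # OIHW kernel shape
                                     ('NCHW', 'OIHW', 'NCHW'))
print(dn_nchw)
# ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))
```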


#### SAME padding, no stride, no dilation

```python
out = lax.conv_general_dilated(img,    # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (1,1),  # window strides
                               'SAME', # padding mode
                               (1,1),  # lhs/image dilation
                               (1,1),  # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
```

out shape:  (1, 200, 198, 3)
First output channel:

#### VALID padding, no stride, no dilation
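Since 'VALID' padding only places the kernel where it fully overlaps the input, each spatial dimension shrinks by kernel_size - 1; a quick check of what to expect from the 3x3 kernel here:

```python
# Expected VALID output size: in_size - kernel_size + 1 per spatial dimension.
k = 3
print(200 - k + 1, 198 - k + 1)  # 198 196
```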

```python
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               'VALID', # padding mode
                               (1,1),   # lhs/image dilation
                               (1,1),   # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(np.array(out)[0,:,:,0]);
```

out shape:  (1, 198, 196, 3) DIFFERENT from above!
First output channel:

#### SAME padding, 2,2 stride, no dilation
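With 'SAME' padding, strides simply downsample the output grid: each spatial dimension becomes ceil(in_size / stride). For the stride of 2 below:

```python
import math
# Expected SAME output size with stride s: ceil(in_size / s) per spatial dimension.
s = 2
print(math.ceil(200 / s), math.ceil(198 / s))  # 100 99
```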

```python
out = lax.conv_general_dilated(img,    # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (2,2),  # window strides
                               'SAME', # padding mode
                               (1,1),  # lhs/image dilation
                               (1,1),  # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
```

out shape:  (1, 100, 99, 3)  <-- half the size of above
First output channel:

#### VALID padding, no stride, rhs kernel dilation ~ Atrous convolution (excessive to illustrate)
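Dilating the kernel by a factor d inserts d - 1 zeros between its taps, so its effective extent grows to d*(k-1) + 1; combined with 'VALID' padding this predicts the output shape below:

```python
# Effective kernel size under rhs (kernel) dilation: d*(k-1) + 1.
d, k = 12, 3
k_eff = d * (k - 1) + 1
print(k_eff, 200 - k_eff + 1, 198 - k_eff + 1)  # 25 176 174
```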

```python
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               'VALID', # padding mode
                               (1,1),   # lhs/image dilation
                               (12,12), # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
```

out shape:  (1, 176, 174, 3)
First output channel:

#### VALID padding, no stride, lhs=input dilation ~ Transposed Convolution
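lhs (input) dilation applies the same zero-insertion to the image instead: each spatial dimension stretches from n to d*(n-1) + 1 before the convolution, and the 'VALID' 3x3 kernel then trims 2 from each, which predicts the enlarged shape below:

```python
# Dilated input size d*(n-1) + 1, then a VALID conv with k=3 trims k - 1 = 2.
d, k = 2, 3
for n in (200, 198):
    print(d * (n - 1) + 1 - (k - 1))  # 397, then 393
```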

```python
out = lax.conv_general_dilated(img,               # lhs = image tensor
                               kernel,            # rhs = conv kernel tensor
                               (1,1),             # window strides
                               ((0, 0), (0, 0)),  # padding mode
                               (2,2),             # lhs/image dilation
                               (1,1),             # rhs/kernel dilation
                               dn)                # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
```

out shape:  (1, 397, 393, 3) <-- larger than original!
First output channel:

We can use the last of these, for instance, to implement transposed convolutions:

```python
# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))

# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg:
kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img,        # lhs = image tensor
                               kernel_rot, # rhs = conv kernel tensor
                               (1,1),      # window strides
                               padding,    # padding mode
                               (2,2),      # lhs/image dilation
                               (1,1),      # rhs/kernel dilation
                               dn)         # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(np.array(out)[0,:,:,0]);
```

out shape:  (1, 400, 396, 3) <-- transposed_conv
First output channel:

### 1D Convolutions

You aren’t limited to 2D convolutions; a simple 1D demo is below:

```python
# 1D kernel - WIO layout
kernel = jnp.array([[[1, 0, -1], [-1,  0,  1]],
                    [[1, 1,  1], [-1, -1, -1]]],
                   dtype=jnp.float32).transpose([2,1,0])
# 1D data - NWC layout
data = np.zeros((1, 200, 2), dtype=jnp.float32)
for i in range(2):
    for k in range(2):
        x = 35*i + 30 + 60*k
        data[0, x:x+30, k] = 1.0

print("in shapes:", data.shape, kernel.shape)

plt.figure(figsize=(10,5))
plt.plot(data[0]);
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NWC', 'WIO', 'NWC'))
print(dn)

out = lax.conv_general_dilated(data,   # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (1,),   # window strides
                               'SAME', # padding mode
                               (1,),   # lhs/image dilation
                               (1,),   # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(out[0]);
```

in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape:  (1, 200, 2)

### 3D Convolutions

```python
import matplotlib as mpl

# Random 3D kernel - HWDIO layout
kernel = jnp.array([
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]],
  [[0, -1, 0], [-1, 0, -1], [0,  -1,  0]],
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]]],
  dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]

# 3D data - NHWDC layout
data = jnp.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)
x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (jnp.sin(2*x*jnp.pi)*jnp.cos(2*y*jnp.pi)*jnp.cos(2*z*jnp.pi))[None,:,:,:,None]

print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)

out = lax.conv_general_dilated(data,    # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1,1), # window strides
                               'SAME',  # padding mode
                               (1,1,1), # lhs/image dilation
                               (1,1,1), # rhs/kernel dilation
                               dn)      # dimension_numbers
print("out shape: ", out.shape)

# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D

def make_alpha(cmap):
    my_cmap = cmap(jnp.arange(cmap.N))
    my_cmap[:, -1] = jnp.linspace(0, 1, cmap.N)**3
    return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
```

in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape:  (1, 30, 30, 30, 1)
