
🔪 JAX - The Sharp Bits 🔪

levskaya@ mattjj@

When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale” (“a soul of pure functional programming”).

JAX is a language for expressing and composing transformations of numerical programs. As such it needs to control the unwanted proliferation of side-effects in its programs so that analysis and transformation of its computations remain tractable!

This requires us to write code in a functional style with explicit descriptions of how the state of a program changes, which results in several important differences from how you might be used to programming in NumPy, TensorFlow, or PyTorch.

Herein we try to cover the most frequent points of trouble that users encounter when starting out in JAX.

[1]:
import numpy as onp
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False

🔪 In-Place Updates

In Numpy you’re used to doing this:

[2]:
numpy_array = onp.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

If we try to update a JAX device array in-place, however, we get an error! (☉_☉)

[3]:
jax_array = np.zeros((3,3), dtype=np.float32)

# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-a717a200f584> in <module>
      2
      3 # In place update of JAX's array will yield an error!
----> 4 jax_array[1, :] = 1.0

~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.7/site-packages/jax/numpy/lax_numpy.py in _unimplemented_setitem(self, i, x)
   3233          "immutable; perhaps you want jax.ops.index_update or "
   3234          "jax.ops.index_add instead?")
-> 3235   raise TypeError(msg.format(type(self)))
   3236
   3237 _operators = {

TypeError: '<class 'jax.lax.lax._FilledConstant'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

What gives?!

Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program.

Instead, JAX offers the functional update functions: index_update, index_add, index_min, index_max, and the index helper.

NB: Fancy Indexing is not yet supported, but will likely be added to JAX soon.

️⚠️ Inside jit’d code and lax.while_loop or lax.fori_loop, the size of slices can’t be a function of argument values but only a function of argument shapes – the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation.
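As a quick illustration (a minimal sketch, not one of the original cells): lax.dynamic_slice_in_dim inside jit accepts a traced start index, but the slice size must be a Python constant.

from jax import jit, lax
import jax.numpy as np

@jit
def take_two(x, start):
  # start may be a traced runtime value, but the slice size (2) is fixed at trace time
  return lax.dynamic_slice_in_dim(x, start, 2)

print(take_two(np.arange(10.), 3))  # --> [3. 4.]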

[4]:
from jax.ops import index, index_add, index_update

index_update

If the input values of index_update aren’t reused, jit-compiled code will perform these operations in-place.

[5]:
jax_array = np.zeros((3, 3))
print("original array:")
print(jax_array)

new_jax_array = index_update(jax_array, index[1, :], 1.)

print("old array unchanged:")
print(jax_array)

print("new array:")
print(new_jax_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
old array unchanged:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
new array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
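
The same functional update can be used inside jit, where XLA may compile it into an in-place operation when the input buffer isn’t reused; a minimal sketch (not one of the original cells):

from jax import jit
from jax.ops import index, index_update
import jax.numpy as np

@jit
def set_middle_row(x):
  # x isn't reused elsewhere, so XLA is free to perform this update in place
  return index_update(x, index[1, :], 1.)

print(set_middle_row(np.zeros((3, 3))))
# [[0. 0. 0.]
#  [1. 1. 1.]
#  [0. 0. 0.]]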

index_add

If the input values of index_add aren’t reused, jit-compiled code will perform these operations in-place.

[6]:
print("original array:")
jax_array = np.ones((5, 6))
print(jax_array)

new_jax_array = index_add(jax_array, index[::2, 3:], 7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]
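
index_min and index_max follow the same pattern but aren’t demonstrated above; a minimal sketch (not one of the original cells) that clamps rows of an array:

from jax.ops import index, index_min, index_max
import jax.numpy as np

jax_array = np.arange(6.).reshape((2, 3))  # [[0. 1. 2.]
                                           #  [3. 4. 5.]]

# row 0 becomes the elementwise minimum of its values and 1.
print(index_min(jax_array, index[0, :], 1.))  # --> [[0. 1. 1.] [3. 4. 5.]]
# row 1 becomes the elementwise maximum of its values and 4.
print(index_max(jax_array, index[1, :], 4.))  # --> [[0. 1. 2.] [4. 4. 5.]]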

🔪 Random Numbers

If all scientific papers whose results are in doubt because of bad rand()s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist. - Numerical Recipes

RNGs and State

You’re used to stateful pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:

[7]:
print(onp.random.random())
print(onp.random.random())
print(onp.random.random())
0.11212922979678686
0.1564047889614152
0.8561243765676295

Underneath the hood, numpy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of \(2^{19937}-1\) and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this “entropy” has been used up.

[8]:
onp.random.seed(0)
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
#       2481403966, 4042607538,  337614300, ... 614 more numbers...,
#       3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)

This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, “consuming” 2 of the uint32s in the Mersenne twister state vector:

[9]:
_ = onp.random.uniform()
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = onp.random.uniform()
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = onp.random.uniform()
rng_state = onp.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
#      4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)

The problem with magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.

The Mersenne Twister PRNG is also known to have a number of problems: it has a large ~2.5KB state size, which leads to problematic initialization issues; it fails modern BigCrush tests; and it is generally slow.

JAX PRNG

JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.

The random state is described by two unsigned 32-bit ints that we call a key:

[10]:
from jax import random
key = random.PRNGKey(0)
key
[10]:
DeviceArray([0, 0], dtype=uint32)

JAX’s random functions produce pseudorandom numbers from the PRNG state, but do not change the state!

Reusing the same state will cause sadness and monotony, depriving the end user of life-giving chaos:

[11]:
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584235]
[0 0]
[-0.20584235]
[0 0]

Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:

[12]:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [0 0]
    \---SPLIT --> new key    [4146024105  967050713]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]

We propagate the key and make new subkeys whenever we need a new random number:

[13]:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [4146024105  967050713]
    \---SPLIT --> new key    [2384771982 3928867769]
             \--> new subkey [1278412471 2182328957] --> normal [-0.58665067]

We can generate more than one subkey at a time:

[14]:
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
[-0.3753345]
[0.9864503]
[0.1455319]

🔪 Control Flow

✔ python control_flow + autodiff ✔

If you just want to apply grad to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).

[15]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
-4.0

python control flow + JIT

Using control flow with jit is more complicated, and by default it has more constraints.

This works:

[16]:
@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
24

So does this:

[17]:
@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(np.array([1., 2., 3.])))
6.0

But this doesn’t, at least by default:

[18]:
@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
try:
  f(2)
except Exception as e:
  print("ERROR:", e)
ERROR: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

What gives!?

When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don’t have to re-compile on each function evaluation.

For example, if we evaluate an @jit function on the array np.array([1., 2., 3.], np.float32), we might want to compile code that we can reuse to evaluate the function on np.array([4., 5., 6.], np.float32) to save on compile time.
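
One way to see this caching in action (a minimal sketch, not one of the original cells) is to add a Python-level print, which only executes while the function is being traced:

from jax import jit
import jax.numpy as np

@jit
def add_one(x):
  print("tracing add_one!")  # Python side effect: runs during tracing, not on cached calls
  return x + 1

add_one(np.array([1., 2., 3.]))  # traces and compiles: prints once
add_one(np.array([4., 5., 6.]))  # same shape and dtype: reuses the cached compilation, no print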

To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.

By default, jit traces your code on the ShapedArray abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), np.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.

But there’s a tradeoff here: if we trace a Python function on a ShapedArray((), np.float32) that isn’t committed to a specific concrete value, when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), np.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don’t know which branch to take, and can’t continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.

The good news is that you can control this tradeoff yourself. By having jit trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnums argument to jit, we can specify to trace on concrete values of some arguments. Here’s that example function again:

[19]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.))
12.0

Here’s another example, this time involving a loop:

[20]:
def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnums=(1,))

f(np.array([2., 3., 4.]), 2)
[20]:
DeviceArray(5., dtype=float32)

In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that’s not currently the default for any transformation.
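
You can see the unrolling directly (a sketch using jax.make_jaxpr, not part of the original notebook): with n fixed at trace time, the recorded jaxpr is straight-line code with no loop construct.

from jax import make_jaxpr
import jax.numpy as np

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

# n = 2 is baked in during tracing, so the loop unrolls into repeated index-and-add steps
print(make_jaxpr(lambda x: f(x, 2))(np.array([2., 3., 4.])))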

️⚠️ functions with argument-value dependent shapes

These control-flow issues also come up in a more subtle way: numerical functions we want to jit can’t specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let’s make a function whose output happens to depend on the input variable length.

[21]:
def example_fun(length, val):
  return np.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))

bad_example_jit = jit(example_fun)
# this will fail:
try:
  print(bad_example_jit(10, 4))
except Exception as e:
  print("error!", e)
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4.]
error! `full` requires shapes to be concrete. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]

static_argnums can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot!

Lastly, if your function has global side-effects, JAX’s tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit’d functions:

[22]:
@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>
Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>
[22]:
DeviceArray(4, dtype=int32)

Structured control flow primitives

There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that’s traceable, and that avoids unrolling large loops. Then you can use these 4 structured control flow primitives:

  • lax.cond will be differentiable soon
  • lax.while_loop non-differentiable*
  • lax.fori_loop non-differentiable*
  • lax.scan differentiable

*These can in principle be made to be forward-differentiable, but this isn’t on the current roadmap.

cond

python equivalent:

def cond(pred, true_operand, true_fun, false_operand, false_fun):
  if pred:
    return true_fun(true_operand)
  else:
    return false_fun(false_operand)
[23]:
from jax import lax

operand = np.array([0.])
lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([1.], dtype=float32)
lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([-1.], dtype=float32)
[23]:
DeviceArray([-1.], dtype=float32)

while_loop

python equivalent:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
[24]:
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
[24]:
DeviceArray(10, dtype=int32)

fori_loop

python equivalent:

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
[25]:
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
[25]:
DeviceArray(45, dtype=int32)
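
scan

lax.scan isn’t demonstrated above; the following is a minimal sketch rather than one of the original cells.

python equivalent:

def scan(f, init, xs):
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

A cumulative sum as a usage example:

def cumsum_step(carry, x):
  carry = carry + x
  return carry, carry  # (new carry, per-step output)

total, cumulative = lax.scan(cumsum_step, np.float32(0.), np.arange(5.))
# total --> 10.0
# cumulative --> [ 0.  1.  3.  6. 10.]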

Summary

\[\begin{array}{r|rr}
\hline
\textrm{construct} & \textrm{jit} & \textrm{grad} \\
\hline
\textrm{if} & ❌ & ✔ \\
\textrm{for} & ✔* & ✔ \\
\textrm{while} & ✔* & ✔ \\
\textrm{lax.cond} & ✔ & \textrm{soon!} \\
\textrm{lax.while\_loop} & ✔ & ❌ \\
\textrm{lax.fori\_loop} & ✔ & ❌ \\
\textrm{lax.scan} & ✔ & ✔ \\
\hline
\end{array}\]

\(\ast\) = argument-value-independent loop condition - unrolls the loop

🔪 Convolutions

JAX and XLA offer the very general N-dimensional conv_general_dilated function, but it’s not very obvious how to use it. We’ll give some examples of the common use-cases. There are also the convenience functions lax.conv and lax.conv_with_general_padding for the most common kinds of convolutions.

A survey of the family of convolutional operators, a guide to convolutional arithmetic, is highly recommended reading!

Let’s define a simple diagonal edge kernel:

[26]:
# 2D kernel - HWIO layout
kernel = onp.zeros((3, 3, 3, 3), dtype=np.float32)
kernel += onp.array([[1, 1, 0],
                [1, 0,-1],
                [0,-1,-1]])[:, :, onp.newaxis, onp.newaxis]

print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
Edge Conv kernel:

And we’ll make a simple synthetic image:

[27]:
# NHWC layout
img = onp.zeros((1, 200, 198, 3), dtype=np.float32)
for k in range(3):
    x = 30 + 60*k
    y = 20 + 60*k
    img[0, x:x+10, y:y+10, k] = 1.0

print("Original Image:")
plt.imshow(img[0]);
Original Image:

lax.conv and lax.conv_with_general_padding

These are the simple convenience functions for convolutions

️⚠️ The convenience lax.conv and lax.conv_with_general_padding helper functions assume NCHW images and IOHW kernels.

[28]:
out = lax.conv(np.transpose(img,[0,3,1,2]),    # lhs = NCHW image tensor
               np.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor
               (1, 1),  # window strides
               'SAME') # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,0,:,:]);
out shape:  (1, 3, 200, 198)
First output channel:
[29]:
out = lax.conv_with_general_padding(
  np.transpose(img,[0,3,1,2]),    # lhs = NCHW image tensor
  np.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor
  (1, 1),  # window strides
  ((2,2),(2,2)), # general padding 2x2
  (1,1),  # lhs/image dilation
  (1,1))  # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,0,:,:]);
out shape:  (1, 3, 202, 200)
First output channel:

Dimension Numbers define dimensional layout for conv_general_dilated

The important argument is the 3-tuple of axis layout arguments: (Input Layout, Kernel Layout, Output Layout)

  • N - batch dimension
  • H - spatial height
  • W - spatial width
  • C - channel dimension
  • I - kernel input channel dimension
  • O - kernel output channel dimension

⚠️ To demonstrate the flexibility of dimension numbers we choose an NHWC image and HWIO kernel convention for lax.conv_general_dilated below.

[30]:
dn = lax.conv_dimension_numbers(img.shape,     # only ndim matters, not shape
                                kernel.shape,  # only ndim matters, not shape
                                ('NHWC', 'HWIO', 'NHWC'))  # the important bit
print(dn)
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))

SAME padding, no stride, no dilation

[31]:
out = lax.conv_general_dilated(img,    # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (1,1),  # window strides
                               'SAME', # padding mode
                               (1,1),  # lhs/image dilation
                               (1,1),  # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,:,:,0]);
out shape:  (1, 200, 198, 3)
First output channel:

VALID padding, no stride, no dilation

[32]:
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               'VALID', # padding mode
                               (1,1),   # lhs/image dilation
                               (1,1),   # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,:,:,0]);
out shape:  (1, 198, 196, 3) DIFFERENT from above!
First output channel:

SAME padding, 2,2 stride, no dilation

[33]:
out = lax.conv_general_dilated(img,    # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (2,2),  # window strides
                               'SAME', # padding mode
                               (1,1),  # lhs/image dilation
                               (1,1),  # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape:  (1, 100, 99, 3)  <-- half the size of above
First output channel:

VALID padding, no stride, rhs kernel dilation ~ Atrous convolution (excessive to illustrate)

[34]:
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               'VALID', # padding mode
                               (1,1),   # lhs/image dilation
                               (12,12), # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape:  (1, 176, 174, 3)
First output channel:

VALID padding, no stride, lhs=input dilation ~ Transposed Convolution

[35]:
out = lax.conv_general_dilated(img,               # lhs = image tensor
                               kernel,            # rhs = conv kernel tensor
                               (1,1),             # window strides
                               ((0, 0), (0, 0)),  # padding mode
                               (2,2),             # lhs/image dilation
                               (1,1),             # rhs/kernel dilation
                               dn)                # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape:  (1, 397, 393, 3) <-- larger than original!
First output channel:

We can use the last of these to, for instance, implement transposed convolutions:

[36]:
# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))

# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg:
kernel_rot = np.rot90(np.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel_rot,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               padding, # padding mode
                               (2,2),   # lhs/image dilation
                               (1,1),   # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape:  (1, 400, 396, 3) <-- transposed_conv
First output channel:

1D Convolutions

You aren’t limited to 2D convolutions; a simple 1D demo is below:

[37]:
# 1D kernel - WIO layout
kernel = onp.array([[[1, 0, -1], [-1,  0,  1]],
                    [[1, 1,  1], [-1, -1, -1]]],
                    dtype=np.float32).transpose([2,1,0])
# 1D data - NWC layout
data = onp.zeros((1, 200, 2), dtype=np.float32)
for i in range(2):
  for k in range(2):
      x = 35*i + 30 + 60*k
      data[0, x:x+30, k] = 1.0

print("in shapes:", data.shape, kernel.shape)

plt.figure(figsize=(10,5))
plt.plot(data[0]);
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NWC', 'WIO', 'NWC'))
print(dn)

out = lax.conv_general_dilated(data,   # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (1,),   # window strides
                               'SAME', # padding mode
                               (1,),   # lhs/image dilation
                               (1,),   # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(out[0]);
in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape:  (1, 200, 2)

3D Convolutions

[38]:
# Random 3D kernel - HWDIO layout
kernel = onp.array([
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]],
  [[0, -1, 0], [-1, 0, -1], [0,  -1,  0]],
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]]],
  dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis]

# 3D data - NHWDC layout
data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32)
x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None]

print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)

out = lax.conv_general_dilated(data,    # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1,1), # window strides
                               'SAME',  # padding mode
                               (1,1,1), # lhs/image dilation
                               (1,1,1), # rhs/kernel dilation
                               dn)      # dimension_numbers
print("out shape: ", out.shape)

# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
  my_cmap = cmap(np.arange(cmap.N))
  my_cmap[:,-1] = np.linspace(0, 1, cmap.N)**3
  return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape:  (1, 30, 30, 30, 1)

🔪 NaNs

Debugging NaNs

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:

  • setting the JAX_DEBUG_NANS=True environment variable;
  • adding from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file;
  • adding from jax.config import config and config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.

This will cause computations to error-out immediately on production of a NaN.
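
For example, a minimal sketch of the second option (not one of the original cells); with the flag set, a NaN-producing computation raises an error instead of silently returning nan:

from jax.config import config
config.update("jax_debug_nans", True)  # must be set before running the computation

import jax.numpy as np

def log_of(x):
  return np.log(x)  # log of a negative number produces a NaN

# log_of(-1.)  # with the NaN-checker on, this errors out immediately rather than returning nan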

⚠️ You shouldn’t have the NaN-checker on if you’re not debugging, as it can introduce lots of device-host round-trips and performance regressions!

Double (64bit) precision

At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API’s tendency to aggressively promote operands to double. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!

[39]:
x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)
x.dtype
[39]:
dtype('float32')

To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.

There are a few ways to do this:

  1. You can enable 64bit mode by setting the environment variable JAX_ENABLE_X64=True.
  2. You can manually set the jax_enable_x64 configuration flag at startup:
# again, this only works on startup!
from jax.config import config
config.update("jax_enable_x64", True)
  3. You can parse command-line flags with absl.app.run(main)
from jax.config import config
config.config_with_absl()
  4. If you want JAX to run absl parsing for you, i.e. you don’t want to do absl.app.run(main), you can instead use
from jax.config import config
if __name__ == '__main__':
  # calls config.config_with_absl() *and* runs absl parsing
  config.parse_flags_with_absl()

Note that #2-#4 work for any of JAX’s configuration options.

We can then confirm that x64 mode is enabled:

[40]:
from jax import numpy as np, random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)
x.dtype # --> dtype('float64')
[40]:
dtype('float32')

Caveats

⚠️ XLA doesn’t support 64-bit convolutions on all backends!

Fin.

If something’s not covered here that has caused you weeping and gnashing of teeth, please let us know and we’ll extend these introductory advisories!