# 🔪 JAX - The Sharp Bits 🔪

levskaya@ mattjj@

When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale” (“a soul of pure functional programming”).

JAX is a language for expressing and composing transformations of numerical programs. As such, it needs to control the unwanted proliferation of side-effects in its programs so that analysis and transformation of its computations remain tractable!

This requires us to write code in a functional style with explicit descriptions of how the state of a program changes, which results in several important differences from how you might be used to programming in NumPy, TensorFlow or PyTorch.

Herein we try to cover the most frequent points of trouble that users encounter when starting out in JAX.

```python
# NB: this notebook uses `np` for jax.numpy and `onp` for the original NumPy
import numpy as onp
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False
```


In NumPy you’re used to doing this:

```python
numpy_array = onp.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
```

```
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
```


If we try to update a JAX device array in-place, however, we get an error! (☉_☉)

```python
jax_array = np.zeros((3,3), dtype=np.float32)

# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
```

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-a717a200f584> in <module>
      2
      3 # In place update of JAX's array will yield an error!
----> 4 jax_array[1, :] = 1.0

   3233          "immutable; perhaps you want jax.ops.index_update or "
-> 3235   raise TypeError(msg.format(type(self)))
   3236
   3237 _operators = {

TypeError: '<class 'jax.lax.lax._FilledConstant'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
```


What gives?!

Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program.

Instead, JAX offers functional update functions in jax.ops: index_update, index_add, index_min, index_max, and the index helper.

NB: Fancy Indexing is not yet supported, but will likely be added to JAX soon.

⚠️ Inside jit’d code, and inside lax.while_loop or lax.fori_loop, the size of a slice can’t be a function of argument values, only of argument shapes; the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation.

```python
from jax.ops import index, index_add, index_update
```


### index_update

If the input values of index_update aren’t reused, jit-compiled code will perform these operations in-place.

```python
jax_array = np.zeros((3, 3))
print("original array:")
print(jax_array)

new_jax_array = index_update(jax_array, index[1, :], 1.)

print("old array unchanged:")
print(jax_array)

print("new array:")
print(new_jax_array)
```

```
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
old array unchanged:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
new array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
```


### index_add

index_add works the same way, except that it adds to the selected entries instead of overwriting them:

```python
print("original array:")
jax_array = np.ones((5, 6))
print(jax_array)

new_jax_array = index_add(jax_array, index[::2, 3:], 7.)
print(new_jax_array)
```

```
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]
```

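To make the semantics concrete, here is a minimal NumPy sketch of what index_update and index_add compute. The helpers np_index_update and np_index_add are hypothetical stand-ins written for illustration, not part of JAX, and onp.index_exp plays the role of jax.ops.index. The point is that the input is copied and never mutated; only the returned array differs.

```python
import numpy as onp

def np_index_update(x, idx, y):
    """Hypothetical NumPy stand-in for index_update: returns a modified copy."""
    out = onp.array(x, copy=True)  # the caller's array is never touched
    out[idx] = y
    return out

def np_index_add(x, idx, y):
    """Hypothetical NumPy stand-in for index_add (basic slices only, see below)."""
    out = onp.array(x, copy=True)
    out[idx] += y
    return out

# onp.index_exp builds the same tuple of slices that index[::2, 3:] describes
x = onp.ones((5, 6))
y = np_index_add(x, onp.index_exp[::2, 3:], 7.)
print(x.sum(), y.sum())  # the input still sums to 30; the copy has nine 8s
```

The in-place `+=` on the copy is only a faithful "add" for basic slices; with repeated integer indices NumPy applies each update once instead of accumulating, which is what onp.add.at is for.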

## 🔪 Random Numbers

If all scientific papers whose results are in doubt because of bad rand()s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist. - Numerical Recipes

### RNGs and State

You’re used to stateful pseudorandom number generators (PRNGs) from NumPy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:

```python
print(onp.random.random())
print(onp.random.random())
print(onp.random.random())
```

```
0.11212922979678686
0.1564047889614152
0.8561243765676295
```


Under the hood, NumPy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of $$2^{19937}-1$$ and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this “entropy” has been used up.

```python
onp.random.seed(0)
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
#       2481403966, 4042607538,  337614300, ... 614 more numbers...,
#       3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
```


This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, “consuming” 2 of the uint32s in the Mersenne twister state vector:

```python
_ = onp.random.uniform()
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
    _ = onp.random.uniform()
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = onp.random.uniform()
rng_state = onp.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
#      4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
```


The problem with magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.

The Mersenne Twister PRNG is also known to have a number of problems: it has a large 2.5 KB state size, which leads to problematic initialization issues; it fails modern BigCrush tests; and it is generally slow.

### JAX PRNG

JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.

The random state is described by two unsigned 32-bit integers that we call a key:

```python
from jax import random
key = random.PRNGKey(0)
key
```

```
DeviceArray([0, 0], dtype=uint32)
```


JAX’s random functions produce pseudorandom numbers from the PRNG state, but do not change the state!

Reusing the same state will cause sadness and monotony, depriving the end user of life-giving chaos:

```python
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
```

```
[-0.20584235]
[0 0]
[-0.20584235]
[0 0]
```


Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:

```python
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
# raw strings keep the backslashes literal (avoids invalid-escape warnings)
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```

```
old key [0 0]
    \---SPLIT --> new key    [4146024105  967050713]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]
```


We propagate the key and make new subkeys whenever we need a new random number:

```python
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```

```
old key [4146024105  967050713]
    \---SPLIT --> new key    [2384771982 3928867769]
             \--> new subkey [1278412471 2182328957] --> normal [-0.58665067]
```


We can generate more than one subkey at a time:

```python
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
    print(random.normal(subkey, shape=(1,)))
```

```
[-0.3753345]
[0.9864503]
[0.1455319]
```

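The explicit key-passing discipline can be modelled in a few lines of plain Python. The sketch below is a toy: it uses SHA-256 from hashlib as its mixing function instead of Threefry, and toy_prng_key, toy_split and toy_uniform are hypothetical names, not JAX functions. It shows the two properties described above: drawing from a key never changes it (so reusing a key repeats the numbers), and splitting derives fresh child keys deterministically from a parent key and a counter.

```python
import hashlib
import struct

MASK64 = (1 << 64) - 1

def _mix(key, counter, tag):
    """Hash (key, counter, tag) to 64 pseudorandom bits; SHA-256 stands in for Threefry."""
    data = struct.pack("<QQB", key & MASK64, counter & MASK64, tag)
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "little")

def toy_prng_key(seed):
    """A key is just an integer; there is no hidden global state anywhere."""
    return seed & MASK64

def toy_split(key, num=2):
    """Derive `num` child keys from counters 0..num-1 of the parent key."""
    return [_mix(key, i, tag=0) for i in range(num)]

def toy_uniform(key, n):
    """Draw n floats in [0, 1); tag=1 keeps these counters apart from toy_split's."""
    return [(_mix(key, i, tag=1) >> 11) / float(1 << 53) for i in range(n)]

key = toy_prng_key(0)
same_1 = toy_uniform(key, 3)
same_2 = toy_uniform(key, 3)   # reusing a key repeats the "random" numbers
key, subkey = toy_split(key)   # split first, then draw from the subkey
fresh = toy_uniform(subkey, 3)
print(same_1 == same_2, fresh == same_1)
```

Because every draw is a pure function of (key, counter), the same program produces the same numbers on any thread or device, which is the property that makes this style easy to reason about.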

## 🔪 Control Flow

### ✔ Python control flow + autodiff ✔

If you just want to apply grad to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).

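```python
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x

print(grad(f)(2.))  # x < 3 branch: d/dx 3x^2 = 6x = 12
print(grad(f)(4.))  # else branch: d/dx -4x = -4
```

```
12.0
-4.0
```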


### Python control flow + JIT

Using control flow with jit is more complicated, and by default it has more constraints.

This works:

```python
@jit
def f(x):
    for i in range(3):
        x = 2 * x
    return x

print(f(3))
```

```
24
```


So does this:

```python
@jit
def g(x):
    y = 0.
    for i in range(x.shape[0]):  # the shape is known at trace time
        y = y + x[i]
    return y

print(g(np.array([1., 2., 3.])))
```

```
6.0
```


But this doesn’t, at least by default:

```python
@jit
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x

# This will fail!
try:
    f(2)
except Exception as e:
    print("ERROR:", e)
```

```
ERROR: Abstract value passed to bool, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using jit, try using static_argnums or applying jit to smaller subfunctions instead.
```


What gives!?

When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don’t have to re-compile on each function evaluation.

For example, if we evaluate an @jit function on the array np.array([1., 2., 3.], np.float32), we might want to compile code that we can reuse to evaluate the function on np.array([4., 5., 6.], np.float32) to save on compile time.

To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.

By default, jit traces your code on the ShapedArray abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), np.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.

But there’s a tradeoff here: if we trace a Python function on a ShapedArray((), np.float32) that isn’t committed to a specific concrete value, when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), np.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don’t know which branch to take, and can’t continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
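A toy model may help make this failure concrete. The classes below are hypothetical stand-ins for abstract values (they are not JAX’s tracers): comparing an AbstractScalar returns an AbstractBool, and Python’s if statement then calls bool() on it, which has no single answer. A function without a value-dependent branch traces fine.

```python
class AbstractBool:
    """Toy stand-in for ShapedArray((), bool_): it stands for the whole set {True, False}."""
    def __bool__(self):
        raise TypeError("Abstract value passed to bool, which requires a concrete value")

class AbstractScalar:
    """Toy stand-in for ShapedArray((), float32): some scalar whose value is unknown."""
    def __lt__(self, other):
        return AbstractBool()      # the comparison result is abstract too
    def __mul__(self, other):
        return AbstractScalar()
    __rmul__ = __mul__
    def __pow__(self, other):
        return AbstractScalar()

def toy_trace(fn):
    """Run fn on an abstract input and report whether tracing could finish."""
    try:
        fn(AbstractScalar())
        return "traced"
    except TypeError as err:
        return "failed: " + str(err)

def f(x):
    if x < 3:                      # Python calls bool(x < 3) here
        return 3. * x ** 2
    else:
        return -4 * x

def g(x):
    return 3. * x ** 2             # no data-dependent branch

print(toy_trace(f))
print(toy_trace(g))
```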

The good news is that you can control this tradeoff yourself. By having jit trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnums argument to jit, we can specify to trace on concrete values of some arguments. Here’s that example function again:

```python
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.))
```

```
12.0
```


Here’s another example, this time involving a loop:

```python
def f(x, n):
    y = 0.
    for i in range(n):
        y = y + x[i]
    return y

f = jit(f, static_argnums=(1,))

f(np.array([2., 3., 4.]), 2)
```

```
DeviceArray(5., dtype=float32)
```


In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that’s not currently the default for any transformation.

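A toy recorder shows what “statically unrolled” means. RecordingValue is a hypothetical stand-in for a tracer that merely logs the operations applied to it; because n is an ordinary Python int (as it is under static_argnums=(1,)), the loop runs at trace time and the recorded program grows linearly with n.

```python
class RecordingValue:
    """Toy tracer (hypothetical): every operation on it is appended to a shared log."""
    def __init__(self, log):
        self.log = log
    def __getitem__(self, i):
        self.log.append(("index", i))
        return RecordingValue(self.log)
    def __add__(self, other):
        self.log.append(("add",))
        return RecordingValue(self.log)
    __radd__ = __add__             # handles 0. + RecordingValue

def f(x, n):
    y = 0.
    for i in range(n):             # n is concrete, so this loop runs while tracing
        y = y + x[i]
    return y

def trace_ops(n):
    """Return the straight-line program recorded for a given static n."""
    log = []
    f(RecordingValue(log), n)
    return log

print(trace_ops(2))
```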
⚠️ functions with argument-**value dependent shapes**

These control-flow issues also come up in a more subtle way: numerical functions we want to jit can’t specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let’s make a function whose output happens to depend on the input variable length.

```python
def example_fun(length, val):
    return np.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))

# this will fail:
try:
    print(jit(example_fun)(10, 4))
except Exception as e:
    print("error!", e)
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
```

```
[4. 4. 4. 4. 4.]
error! full requires shapes to be concrete. If using jit, try using static_argnums or applying jit to smaller subfunctions instead.
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
```


static_argnums can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot!

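The recompilation cost can be pictured as a cache keyed on the arguments. The toy_jit below is a hypothetical sketch, not JAX’s implementation: static arguments enter the cache key by value, every other argument only by shape and dtype, and num_compiles counts how many times a real jit would have traced and compiled.

```python
import numpy as onp

def toy_jit(fn, static_argnums=()):
    """Hypothetical model of jit's compilation cache (not JAX's implementation).

    Static arguments are part of the cache key by value; every other argument
    only contributes its shape and dtype, as with ShapedArray tracing.
    """
    cache = {}

    def wrapped(*args):
        key = tuple(
            ("value", a) if i in static_argnums
            else ("aval", onp.shape(a), onp.asarray(a).dtype.str)
            for i, a in enumerate(args))
        if key not in cache:
            wrapped.num_compiles += 1  # a real jit would trace and compile here
            cache[key] = fn
        return cache[key](*args)

    wrapped.num_compiles = 0
    return wrapped

def example_fun(length, val):
    return onp.ones((length,)) * val

good = toy_jit(example_fun, static_argnums=(0,))
good(10, 4)   # new length: "compile"
good(5, 4)    # another new length: "compile" again
good(10, 7)   # length 10 seen before, val has the same shape/dtype: cache hit
print(good.num_compiles)  # prints 2
```

If length took a new value on almost every call, num_compiles would grow with every call, which is the “disastrous” case mentioned above.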
Lastly, if your function has global side-effects, JAX’s tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit’d functions:

```python
@jit
def f(x):
    print(x)
    y = 2 * x
    print(y)
    return y
f(2)
```

```
Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>
Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>
```

```
DeviceArray(4, dtype=int32)
```


### Structured control flow primitives

There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that’s traceable, and that avoids unrolling large loops. Then you can use these 4 structured control flow primitives:

• lax.cond will be differentiable soon
• lax.while_loop non-differentiable\*
• lax.fori_loop non-differentiable\*
• lax.scan differentiable

\* These can in principle be made forward-differentiable, but this isn’t on the current roadmap.

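Of the four primitives, lax.scan is the only one without a Python equivalent in the subsections that follow. A pure-Python sketch of its usual semantics, assuming the signature scan(f, init, xs) where f maps (carry, x) to (new_carry, y), looks like this; it ignores details such as pytrees and the optional length argument.

```python
import numpy as onp

def scan(f, init, xs):
    """Pure-Python sketch of lax.scan: thread a carry through xs, stacking the outputs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, onp.stack(ys)

def cumsum_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry   # emit the running total at every step

final, partial_sums = scan(cumsum_step, 0., onp.array([1., 2., 3., 4.]))
print(final, partial_sums)
```

Because the loop body is a single function applied a data-independent number of times, a structure like this can be compiled once rather than unrolled.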
#### cond

Python equivalent:

```python
def cond(pred, true_operand, true_fun, false_operand, false_fun):
    if pred:
        return true_fun(true_operand)
    else:
        return false_fun(false_operand)
```

```python
from jax import lax

operand = np.array([0.])
lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([1.], dtype=float32)
lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([-1.], dtype=float32)
```

```
DeviceArray([-1.], dtype=float32)
```


#### while_loop

Python equivalent:

```python
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val
```

```python
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
```

```
DeviceArray(10, dtype=int32)
```


#### fori_loop

Python equivalent:

```python
def fori_loop(start, stop, body_fun, init_val):
    val = init_val
    for i in range(start, stop):
        val = body_fun(i, val)
    return val
```

```python
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
```

```
DeviceArray(45, dtype=int32)
```


#### Summary

| construct | jit | grad |
|:--|:-:|:-:|
| if | ❌ | ✔ |
| for | ✔\* | ✔ |
| while | ✔\* | ✔ |
| lax.cond | ✔ | soon! |
| lax.while_loop | ✔ | ❌ |
| lax.fori_loop | ✔ | ❌ |
| lax.scan | ✔ | ✔ |

\* = argument-value-independent loop condition; unrolls the loop

## 🔪 Convolutions

JAX and XLA offer the very general N-dimensional conv_general_dilated function, but it’s not very obvious how to use it. We’ll give some examples of the common use-cases. There are also the convenience functions lax.conv and lax.conv_with_general_padding for the most common kinds of convolutions.

For a survey of the family of convolutional operators, “A guide to convolution arithmetic for deep learning” is highly recommended reading!

Let’s define a simple diagonal edge kernel:

```python
# 2D kernel - HWIO layout
kernel = onp.zeros((3, 3, 3, 3), dtype=np.float32)
kernel += onp.array([[1, 1, 0],
                     [1, 0,-1],
                     [0,-1,-1]])[:, :, onp.newaxis, onp.newaxis]

print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
```

```
Edge Conv kernel:
```

And we’ll make a simple synthetic image:

```python
# NHWC layout
img = onp.zeros((1, 200, 198, 3), dtype=np.float32)
for k in range(3):
    x = 30 + 60*k
    y = 20 + 60*k
    img[0, x:x+10, y:y+10, k] = 1.0

print("Original Image:")
plt.imshow(img[0]);  # drop the batch axis: imshow needs an (H, W, C) image
```

```
Original Image:
```

These are the simple convenience functions for convolutions:

⚠️ The convenience functions lax.conv and lax.conv_with_general_padding assume NCHW images and OIHW kernels.

```python
out = lax.conv(np.transpose(img,[0,3,1,2]),    # lhs = NCHW image tensor
               np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor
               (1, 1),  # window strides
               'SAME')  # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,0,:,:]);
```

```
out shape:  (1, 3, 200, 198)
First output channel:
```

```python
out = lax.conv_with_general_padding(
    np.transpose(img,[0,3,1,2]),    # lhs = NCHW image tensor
    np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor
    (1, 1),          # window strides
    ((2,2),(2,2)),   # general padding: (low, high) per spatial dim
    (1,1),           # lhs/image dilation
    (1,1))           # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,0,:,:]);
```

```
out shape:  (1, 3, 202, 200)
First output channel:
```

### Dimension Numbers define dimensional layout for conv_general_dilated

The important argument is the 3-tuple of axis layout arguments: (Input Layout, Kernel Layout, Output Layout)

- N - batch dimension
- H - spatial height
- W - spatial width
- C - channel dimension
- I - kernel input channel dimension
- O - kernel output channel dimension

⚠️ To demonstrate the flexibility of dimension numbers, we choose an NHWC image and HWIO kernel convention for lax.conv_general_dilated below.

```python
dn = lax.conv_dimension_numbers(img.shape,     # only ndim matters, not shape
                                kernel.shape,  # only ndim matters, not shape
                                ('NHWC', 'HWIO', 'NHWC'))  # the important bit
print(dn)
```

```
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
```

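The printed ConvDimensionNumbers are axis positions: for the input and output, the index of N, then of C, then of the spatial axes; for the kernel, the index of O, then of I, then of the spatial axes. The hypothetical helper below reproduces the tuples printed in this notebook under one stated assumption, that spatial axes are taken in the order they appear in the kernel layout string; lax.conv_dimension_numbers itself may treat mismatched spatial orders differently.

```python
def dimension_permutations(lhs_layout, rhs_layout, out_layout):
    """Hypothetical sketch: map layout strings to (N or O, C or I, spatial...) axis indices."""
    # assumption: spatial axes are ordered as they appear in the kernel layout
    spatial = [c for c in rhs_layout if c not in "OI"]
    lhs_spec = (lhs_layout.index("N"), lhs_layout.index("C")) + tuple(lhs_layout.index(c) for c in spatial)
    rhs_spec = (rhs_layout.index("O"), rhs_layout.index("I")) + tuple(rhs_layout.index(c) for c in spatial)
    out_spec = (out_layout.index("N"), out_layout.index("C")) + tuple(out_layout.index(c) for c in spatial)
    return lhs_spec, rhs_spec, out_spec

print(dimension_permutations("NHWC", "HWIO", "NHWC"))
```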

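The output shapes printed in the padding, stride and dilation examples that follow are standard convolution arithmetic, applied to each spatial axis independently. This sketch (conv_out_size is a hypothetical helper) dilates the input and kernel, applies the padding, and divides by the stride; as a simplifying assumption it only handles SAME padding without input dilation.

```python
def conv_out_size(in_size, k, stride=1, padding="VALID", lhs_dilation=1, rhs_dilation=1):
    """Output length of one spatial axis (sketch of standard convolution arithmetic)."""
    dilated_in = (in_size - 1) * lhs_dilation + 1  # zeros inserted between input samples
    eff_k = (k - 1) * rhs_dilation + 1             # holes inserted between kernel taps
    if padding == "SAME":
        if lhs_dilation != 1:
            raise ValueError("this sketch only handles SAME without lhs dilation")
        return -(-in_size // stride)               # ceil(in_size / stride)
    lo, hi = (0, 0) if padding == "VALID" else padding  # explicit (low, high) pair
    return (dilated_in + lo + hi - eff_k) // stride + 1

# 3x3 kernel on the 200 x 198 image, atrous case with rhs dilation 12
print([conv_out_size(s, 3, rhs_dilation=12) for s in (200, 198)])  # [176, 174]
```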
#### SAME padding, no stride, no dilation

```python
out = lax.conv_general_dilated(img,    # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (1,1),  # window strides
                               'SAME', # padding mode
                               (1,1),  # lhs/image dilation
                               (1,1),  # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,:,:,0]);
```

```
out shape:  (1, 200, 198, 3)
First output channel:
```

#### VALID padding, no stride, no dilation

```python
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               'VALID', # padding mode
                               (1,1),   # lhs/image dilation
                               (1,1),   # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,:,:,0]);
```

```
out shape:  (1, 198, 196, 3) DIFFERENT from above!
First output channel:
```

#### SAME padding, 2,2 stride, no dilation

```python
out = lax.conv_general_dilated(img,    # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (2,2),  # window strides
                               'SAME', # padding mode
                               (1,1),  # lhs/image dilation
                               (1,1),  # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
```

```
out shape:  (1, 100, 99, 3)  <-- half the size of above
First output channel:
```

#### VALID padding, no stride, rhs kernel dilation ~ Atrous convolution (excessive to illustrate)

```python
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1),   # window strides
                               'VALID', # padding mode
                               (1,1),   # lhs/image dilation
                               (12,12), # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
```

```
out shape:  (1, 176, 174, 3)
First output channel:
```

#### VALID padding, no stride, lhs=input dilation ~ Transposed Convolution

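Why does input (lhs) dilation give a transposed convolution? A 1D NumPy check makes it concrete: scattering a scaled copy of the kernel at every stride-th output position gives the same result as dilating the input with zeros, padding by k-1 on each side, and correlating with the flipped kernel. Both helpers are hypothetical, written for illustration only.

```python
import numpy as onp

def conv_transpose_scatter(x, w, stride):
    """Reference transposed conv: each input sample scatters a scaled copy of the kernel."""
    k = len(w)
    y = onp.zeros((len(x) - 1) * stride + k)
    for i, xi in enumerate(x):
        y[i * stride : i * stride + k] += xi * w
    return y

def conv_transpose_via_dilation(x, w, stride):
    """Same result: dilate the input, pad by k-1, then correlate with the flipped kernel."""
    k = len(w)
    xd = onp.zeros((len(x) - 1) * stride + 1)
    xd[::stride] = x                    # lhs (input) dilation
    xp = onp.pad(xd, (k - 1, k - 1))    # "full" zero padding
    w_flip = w[::-1]                    # 1D analogue of the 180-degree kernel rotation
    return onp.correlate(xp, w_flip, mode="valid")

x = onp.array([1., 2., 3.])
w = onp.array([1., 0., -1.])
print(conv_transpose_scatter(x, w, 2))
print(conv_transpose_via_dilation(x, w, 2))
```

The 2D transposed-convolution recipe in this notebook uses the same ingredients (lhs dilation, a rotated kernel, and custom padding); the padding there is chosen to trim the “full” result to the desired output size.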
```python
out = lax.conv_general_dilated(img,               # lhs = image tensor
                               kernel,            # rhs = conv kernel tensor
                               (1,1),             # window strides
                               ((0, 0), (0, 0)),  # padding mode
                               (2,2),             # lhs/image dilation
                               (1,1),             # rhs/kernel dilation
                               dn)                # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
```

```
out shape:  (1, 397, 393, 3) <-- larger than original!
First output channel:
```

We can use the last one to, for instance, implement transposed convolutions:

```python
# The following is equivalent to TensorFlow's:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))

# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg:
kernel_rot = np.rot90(np.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img,         # lhs = image tensor
                               kernel_rot,  # rhs = conv kernel tensor
                               (1,1),       # window strides
                               padding,     # explicit (low, high) padding per spatial dim
                               (2,2),       # lhs/image dilation
                               (1,1),       # rhs/kernel dilation
                               dn)          # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
```

```
out shape:  (1, 400, 396, 3) <-- transposed_conv
First output channel:
```

### 1D Convolutions

You aren’t limited to 2D convolutions; a simple 1D demo is below:

```python
# 1D kernel - WIO layout
kernel = onp.array([[[1, 0, -1], [-1,  0,  1]],
                    [[1, 1,  1], [-1, -1, -1]]],
                   dtype=np.float32).transpose([2,1,0])
# 1D data - NWC layout
data = onp.zeros((1, 200, 2), dtype=np.float32)
for i in range(2):
    for k in range(2):
        x = 35*i + 30 + 60*k
        data[0, x:x+30, k] = 1.0

print("in shapes:", data.shape, kernel.shape)

plt.figure(figsize=(10,5))
plt.plot(data[0]);  # plot the (W, C) signal of the single batch element
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NWC', 'WIO', 'NWC'))
print(dn)

out = lax.conv_general_dilated(data,   # lhs = image tensor
                               kernel, # rhs = conv kernel tensor
                               (1,),   # window strides
                               'SAME', # padding mode
                               (1,),   # lhs/image dilation
                               (1,),   # rhs/kernel dilation
                               dn)     # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(onp.array(out)[0]);
```

```
in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape:  (1, 200, 2)
```

### 3D Convolutions

```python
# 3D kernel - HWDIO layout
kernel = onp.array([
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]],
  [[0, -1, 0], [-1, 0, -1], [0,  -1,  0]],
  [[0, 0,  0], [0,  1,  0], [0,  0,   0]]],
  dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis]

# 3D data - NHWDC layout
data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32)
x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None]

print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)

out = lax.conv_general_dilated(data,    # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,1,1), # window strides
                               'SAME',  # padding mode
                               (1,1,1), # lhs/image dilation
                               (1,1,1), # rhs/kernel dilation
                               dn)      # dimension_numbers
print("out shape: ", out.shape)

# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection
def make_alpha(cmap):
    # host-side colormap tweak: plain NumPy arrays support item assignment
    my_cmap = cmap(onp.arange(cmap.N))
    my_cmap[:,-1] = onp.linspace(0, 1, cmap.N)**3
    return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=onp.array(out).ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
```

```
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape:  (1, 30, 30, 30, 1)
```

## 🔪 NaNs

### Debugging NaNs

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:

- setting the JAX_DEBUG_NANS=True environment variable;
- adding from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file;
- adding from jax.config import config and config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.

This will cause computations to error-out immediately on production of a NaN.

⚠️ You shouldn’t have the NaN-checker on if you’re not debugging, as it can introduce lots of device-host round-trips and performance regressions!

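For comparison, plain NumPy has a related but much narrower mechanism: onp.errstate(invalid="raise") turns an operation that produces a NaN (such as 0/0) into a FloatingPointError at the point where it happens. This is an analogy only, not how JAX’s checker is implemented; produces_nan is a hypothetical helper.

```python
import numpy as onp

def ratio(a, b):
    return a / b

def produces_nan(a, b):
    """Return True if computing a / b trips NumPy's 'invalid operation' trap."""
    with onp.errstate(invalid="raise"):
        try:
            ratio(a, b)
            return False
        except FloatingPointError:
            return True

print(produces_nan(onp.array([0.]), onp.array([0.])))  # 0/0 is invalid
print(produces_nan(onp.array([1.]), onp.array([2.])))
```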
## Double (64-bit) precision

At the moment, JAX by default enforces single-precision numbers to mitigate the NumPy API’s tendency to aggressively promote operands to double. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!

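For contrast, here are a few ordinary NumPy operations that all produce double-precision results: new arrays default to float64, integer true division yields float64, and mixing a float32 array with a float64 array promotes to float64.

```python
import numpy as onp

a = onp.zeros(3)                                  # default dtype is float64
b = onp.arange(3) / 2                             # int / int true division -> float64
c = onp.ones(3, dtype=onp.float32) * onp.ones(3)  # float32 * float64 -> float64
print(a.dtype, b.dtype, c.dtype)
```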
```python
x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)
x.dtype
```

```
dtype('float32')
```


To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.

There are a few ways to do this:

1. You can enable 64-bit mode by setting the environment variable JAX_ENABLE_X64=True.
2. You can manually set the jax_enable_x64 configuration flag at startup:

   ```python
   # again, this only works on startup!
   from jax.config import config
   config.update("jax_enable_x64", True)
   ```

3. You can parse command-line flags with absl.app.run(main):

   ```python
   from jax.config import config
   config.config_with_absl()
   ```

4. If you want JAX to run absl parsing for you, i.e. you don’t want to do absl.app.run(main), you can instead use:

   ```python
   from jax.config import config
   if __name__ == '__main__':
       # calls config.config_with_absl() *and* runs absl parsing
       config.parse_flags_with_absl()
   ```


Note that #2-#4 work for any of JAX’s configuration options.

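We can then confirm that x64 mode is enabled. (The output below was recorded without x64 mode enabled, which is why it still reads float32.)

```python
from jax import numpy as np, random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)
x.dtype # --> dtype('float64')
```

```
dtype('float32')
```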


### Caveats

⚠️ XLA doesn’t support 64-bit convolutions on all backends!

## Fin.

If something’s not covered here that has caused you weeping and gnashing of teeth, please let us know and we’ll extend these introductory advisos!