🔪 JAX - The Sharp Bits 🔪

levskaya@ mattjj@

When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.

JAX is a language for expressing and composing transformations of numerical programs. JAX is also able to compile numerical programs for CPU or accelerators (GPU/TPU). JAX works great for many numerical and scientific programs, but only if they are written with certain constraints that we describe below.

import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False

🔪 Pure functions

JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.

Here are some examples of functions that are not functionally pure for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.

def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect 
  return x

# The side-effects appear during the first run  
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]
g = 0.
def impure_uses_globals(x):
  return x + g

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call:  4.0
Second call:  5.0
Third call, different type:  [14.]
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs the transformed function once, with special Traced values as arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value
First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
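
For contrast, here is a minimal sketch of a pure counterpart to impure_uses_globals. The function name pure_uses_argument is hypothetical; the idea is simply that the former global is passed as an explicit parameter, so a changed value is a changed input rather than hidden state.

```python
from jax import jit

def pure_uses_argument(x, g):
    # g is an ordinary argument, so the function depends only on its inputs
    return x + g

pure_jit = jit(pure_uses_argument)

first = pure_jit(4., 0.)     # g = 0.
second = pure_jit(5., 10.)   # the new value of g is used, even with cached code
print(first, second)
```

Because g is traced like any other argument, the second call reuses the compiled code and still sees the updated value.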

A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:

def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))
50.0

It is not recommended to use iterators in any JAX function you want to jit or in any control-flow primitive. The reason is that an iterator is a Python object that carries state in order to retrieve the next element, which makes it incompatible with JAX’s functional programming model. The code below shows some incorrect attempts to use iterators with JAX. Most of them raise an error, but some give unexpected results.

import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)  
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))    
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0

🔪 In-Place Updates

In NumPy you’re used to doing this:

numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

If we try to update a JAX device array in-place, however, we get an error! (☉_☉)

jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
try:
  jax_array[1, :] = 1.0
except Exception as e:
  print("Exception {}".format(e))
Exception '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/jax.ops.html

What gives?!

Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program.

Instead, JAX offers the functional update functions: index_update, index_add, index_min, index_max, and the index helper.

️⚠️ Inside jit’d code and lax.while_loop or lax.fori_loop, the size of slices can’t be a function of argument values, only of argument shapes; the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation.
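
The restriction on slice sizes can be illustrated with lax.dynamic_slice. This is a small sketch: the function name window_of_three and the fixed size 3 are chosen only for illustration. The slice size is a Python constant (static), while the start index is an ordinary traced argument.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def window_of_three(x, start):
    # The slice size (3,) is a Python constant, so it is static under jit;
    # the start index is a traced argument value, which is allowed.
    return lax.dynamic_slice(x, (start,), (3,))

x = jnp.arange(10)
first_window = window_of_three(x, 2)    # elements 2, 3, 4
second_window = window_of_three(x, 5)   # elements 5, 6, 7
print(first_window, second_window)
```

Making the size depend on an argument value instead would require marking that argument as static.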

from jax.ops import index, index_add, index_update

index_update

If the input values of index_update aren’t reused, jit-compiled code will perform these operations in-place.

jax_array = jnp.zeros((3, 3))
print("original array:")
print(jax_array)

new_jax_array = index_update(jax_array, index[1, :], 1.)

print("old array unchanged:")
print(jax_array)

print("new array:")
print(new_jax_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
old array unchanged:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
new array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

index_add

If the input values of index_add aren’t reused, jit-compiled code will perform these operations in-place.

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = index_add(jax_array, index[::2, 3:], 7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]
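
The exception message shown earlier recommends x.at[idx].set(y). As a hedged sketch, the two updates above can also be written with the .at[] property; this assumes a JAX version that provides .at[], and to my knowledge it is the form to prefer in newer releases, where the jax.ops.index_* helpers may no longer be available.

```python
import jax.numpy as jnp

jax_array = jnp.zeros((3, 3))
# Counterpart of index_update(jax_array, index[1, :], 1.)
updated = jax_array.at[1, :].set(1.)

ones = jnp.ones((5, 6))
# Counterpart of index_add(ones, index[::2, 3:], 7.)
added = ones.at[::2, 3:].add(7.)

print(jax_array)  # the original array is unchanged
print(updated)
print(added)
```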

🔪 Out-of-Bounds Indexing

In NumPy, you are used to errors being thrown when you index an array outside of its bounds, like this:

try:
  np.arange(10)[11]
except Exception as e:
  print("Exception {}".format(e))
Exception index 11 is out of bounds for axis 0 with size 10

However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out of bounds indexing (akin to how invalid floating point arithmetic results in NaN). When the indexing operation is an array index update (e.g. index_add or scatter-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or gather-like primitives) the index is clamped to the bounds of the array since something must be returned. For example, the last value of the array will be returned from this indexing operation:

jnp.arange(10)[11]
DeviceArray(9, dtype=int32)

Note that due to this behavior for index retrieval, functions like jnp.nanargmin and jnp.nanargmax return -1 for slices consisting of NaNs, whereas NumPy would throw an error.

Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) will not preserve the semantics of out of bounds indexing. Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.
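
The two behaviors described above can be seen side by side in a short sketch. It uses the .at[] update syntax (an assumption about the JAX version in use) and relies on the default out-of-bounds handling described in this section.

```python
import jax.numpy as jnp

x = jnp.arange(10)
oob = jnp.array([11])          # index past the end of a length-10 array

retrieved = x[oob]             # retrieval: the index is clamped to the last element
updated = x.at[oob].add(100)   # update: the out-of-bounds write is skipped

print(retrieved)  # the last element of x
print(updated)    # same values as x
```

Neither line raises an error, which is why treating such indexing as undefined behavior is the safer mental model.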

🔪 Non-array inputs: NumPy vs. JAX

NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:

np.sum([1, 2, 3])
6

JAX departs from this, generally raising a helpful error:

try:
  jnp.sum([1, 2, 3])
except TypeError as e:
  print(f"TypeError: {e}")
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.

For example, consider the following permissive version of jnp.sum that allows list inputs:

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x)
DeviceArray(45, dtype=int32)

The output is what we would expect, but this hides potential performance issues under the hood. In JAX’s tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the permissive_sum function above:

make_jaxpr(permissive_sum)(x)
{ lambda  ; a b c d e f g h i j.
  let k = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] a
      l = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] b
      m = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] c
      n = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] d
      o = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] e
      p = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] f
      q = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] g
      r = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] h
      s = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] i
      t = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] j
      u = concatenate[ dimension=0 ] k l m n o p q r s t
      v = convert_element_type[ new_dtype=int32
                                weak_type=False ] u
      w = reduce_sum[ axes=(0,) ] v
  in (w,) }

Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.

If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:

jnp.sum(jnp.array(x))
DeviceArray(45, dtype=int32)

🔪 Random Numbers

If all scientific papers whose results are in doubt because of bad rand()s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist. - Numerical Recipes

RNGs and State

You’re used to stateful pseudorandom number generators (PRNGs) from NumPy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:

print(np.random.random())
print(np.random.random())
print(np.random.random())
0.033313390223279726
0.3621651975221786
0.2026328101736048

Under the hood, NumPy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of \(2^{19937}-1\) and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this “entropy” has been used up.

np.random.seed(0)
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
#       2481403966, 4042607538,  337614300, ... 614 more numbers..., 
#       3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)

This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, “consuming” 2 of the uint32s in the Mersenne twister state vector:

_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state) 
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state) 
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state) 
# --> ('MT19937', array([1499117434, 2949980591, 2242547484, 
#      4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)

The problem with magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.

The Mersenne Twister PRNG is also known to have a number of problems: it has a large 2.5kB state size, which leads to problematic initialization issues; it fails modern BigCrush tests; and it is generally slow.

JAX PRNG

JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.

The random state is described by two unsigned-int32s that we call a key:

from jax import random
key = random.PRNGKey(0)
key
DeviceArray([0, 0], dtype=uint32)

JAX’s random functions produce pseudorandom numbers from the PRNG state, but do not change the state!

Reusing the same state will cause sadness and monotony, depriving the end user of lifegiving chaos:

print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584235]
[0 0]
[-0.20584235]
[0 0]

Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [0 0]
    \---SPLIT --> new key    [4146024105  967050713]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]

We propagate the key and make new subkeys whenever we need a new random number:

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [4146024105  967050713]
    \---SPLIT --> new key    [2384771982 3928867769]
             \--> new subkey [1278412471 2182328957] --> normal [-0.58665067]

We can generate more than one subkey at a time:

key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
[-0.37533444]
[0.9864503]
[0.1455319]
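
Putting the splitting pattern into a function might look like the sketch below. The function name two_draws and the seed 42 are arbitrary choices for illustration; the point is that each independent draw gets its own subkey, and that the same input key always reproduces the same outputs.

```python
from jax import random

def two_draws(key):
    # Split once per independent draw; never reuse a key for two draws
    key_a, key_b = random.split(key)
    a = random.normal(key_a, shape=(1,))
    b = random.normal(key_b, shape=(1,))
    return a, b

key = random.PRNGKey(42)
a1, b1 = two_draws(key)
a2, b2 = two_draws(key)   # same input key, so the same outputs again
print(a1, b1)
```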

🔪 Control Flow

✔ Python control flow + autodiff ✔

If you just want to apply grad to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!
12.0
-4.0

Python control flow + JIT

Using control flow with jit is more complicated, and by default it has more constraints.

This works:

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
24

So does this:

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))
6.0

But this doesn’t, at least by default:

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
try:
  f(2)
except Exception as e:
  print("Exception {}".format(e))
Exception Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /tmp/ipykernel_545/2259617891.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

What gives!?

When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don’t have to re-compile on each function evaluation.

For example, if we evaluate an @jit function on the array jnp.array([1., 2., 3.], jnp.float32), we might want to compile code that we can reuse to evaluate the function on jnp.array([4., 5., 6.], jnp.float32) to save on compile time.

To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.

By default, jit traces your code on the ShapedArray abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), jnp.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
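
One rough way to observe this reuse is to record each trace with a Python side effect. The sketch below is a diagnostic trick only; as noted in the Pure functions section, such side effects are not guaranteed behavior. The names trace_log and double are made up for the example.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []   # filled by a Python side effect that runs only while tracing

@jit
def double(x):
    trace_log.append(x.shape)
    return 2 * x

double(jnp.array([1., 2., 3.]))
double(jnp.array([4., 5., 6.]))    # same shape and dtype: cached code is reused
traces_after_same_shape = len(trace_log)

double(jnp.ones(4))                # new shape: traced and compiled again
traces_after_new_shape = len(trace_log)

print(traces_after_same_shape, traces_after_new_shape)
```

The first two calls share one trace because both arguments belong to the same set of float32 arrays of shape (3,).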

But there’s a tradeoff here: if we trace a Python function on a ShapedArray((), jnp.float32) that isn’t committed to a specific concrete value, when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), jnp.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don’t know which branch to take, and can’t continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.

The good news is that you can control this tradeoff yourself. By having jit trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnums argument to jit, we can specify to trace on concrete values of some arguments. Here’s that example function again:

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.))
12.0

Here’s another example, this time involving a loop:

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnums=(1,))

f(jnp.array([2., 3., 4.]), 2)
DeviceArray(5., dtype=float32)

In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that’s not currently the default for any transformation.

️⚠️ functions with argument-value dependent shapes

These control-flow issues also come up in a more subtle way: numerical functions we want to jit can’t specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let’s make a function whose output shape happens to depend on the input variable length.

def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))

bad_example_jit = jit(example_fun)
# this will fail:
try:
  print(bad_example_jit(10, 4))
except Exception as e:
  print("Exception {}".format(e))
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4.]
Exception Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]

static_argnums can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot!

Lastly, if your function has global side-effects, JAX’s tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit’d functions:

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
DeviceArray(4, dtype=int32, weak_type=True)

Structured control flow primitives

There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that’s traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:

  • lax.cond differentiable

  • lax.while_loop fwd-mode-differentiable

  • lax.fori_loop fwd-mode-differentiable

  • lax.scan differentiable

cond

Python equivalent:

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
DeviceArray([-1.], dtype=float32)
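
As a sketch of how lax.cond relates to the earlier branching function f, the version below replaces the Python if with lax.cond. The rewrite is illustrative rather than taken from the text above; it should trace under jit without static_argnums and remain differentiable, in line with the list of primitives.

```python
from jax import grad, jit, lax

def f(x):
    # The predicate is a traced value; lax.cond selects the branch at run time
    return lax.cond(x < 3, lambda v: 3. * v ** 2, lambda v: -4. * v, x)

low = jit(f)(2.)               # 3 * 2**2
high = jit(f)(4.)              # -4 * 4
dlow = jit(grad(f))(2.)        # derivative of 3x^2 at x = 2
dhigh = jit(grad(f))(4.)       # derivative of -4x
print(low, high, dlow, dhigh)
```

Both branches must return values of the same shape and dtype, which is why both use float literals here.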

while_loop

Python equivalent:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
DeviceArray(10, dtype=int32, weak_type=True)

fori_loop

Python equivalent:

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
DeviceArray(45, dtype=int32, weak_type=True)
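
lax.scan is listed among the primitives above but has no example of its own. A minimal sketch, assuming a running sum as the task: the body function (called step here, a made-up name) takes the carry and one element, and returns the new carry together with a per-step output.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    new_carry = carry + x          # running total so far
    return new_carry, new_carry    # (carry for the next step, per-step output)

xs = jnp.array([1., 2., 3., 4.])
init = jnp.zeros((), dtype=xs.dtype)   # carry must keep a fixed shape and dtype
total, running = lax.scan(step, init, xs)

print(total)     # final carry
print(running)   # stacked per-step outputs
```

Unlike a Python for loop under jit, the body is traced once rather than unrolled for every element.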

Summary

construct        jit   grad
if               ❌    ✔
for              ✔*    ✔
while            ✔*    ✔
lax.cond         ✔     ✔
lax.while_loop   ✔     fwd
lax.fori_loop    ✔     fwd
lax.scan         ✔     ✔

* = argument-value-independent loop condition; the loop is unrolled

🔪 NaNs

Debugging NaNs

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:

  • setting the JAX_DEBUG_NANS=True environment variable;

  • adding from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file;

  • adding from jax.config import config and config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.

This will cause computations to error out immediately on production of a NaN. Switching this option on adds a NaN check to every floating-point value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an @jit. For code under an @jit, the output of every @jit function is checked, and if a NaN is present the function is re-run in de-optimized op-by-op mode, effectively removing one level of @jit at a time.

Tricky situations can arise, such as NaNs that occur only under an @jit but are not produced in de-optimized mode. In that case you’ll see a warning message printed, but your code will continue to execute.
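
As a rough sketch of the checker in action from inside a script, the snippet below turns the option on through jax.config.update (this spelling is an assumption about the installed JAX version; the import form listed above is the one this document uses), triggers a NaN, and switches the option off again.

```python
import jax
import jax.numpy as jnp

caught = False
jax.config.update("jax_debug_nans", True)
try:
    jnp.divide(0., 0.)             # 0/0 produces a NaN
except FloatingPointError as err:
    caught = True
    print("NaN detected:", err)
finally:
    # Switch the checker off again; it adds device-host round trips
    jax.config.update("jax_debug_nans", False)
```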

If the NaNs are produced in the backward pass of a gradient evaluation, then when the exception is raised you will find the backward_pass function several frames up in the stack trace. It is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an IPython REPL with the command line env JAX_DEBUG_NANS=True ipython, then ran this:

In [1]: import jax.numpy as jnp

In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

... stack trace ...

.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("invalid value")
    106         else:
    107           return DeviceArray(device_buffer, *result_shape)

FloatingPointError: invalid value

The generated NaN was caught. By running %debug, we can get a post-mortem debugger. This also works with functions under @jit, as the example below shows.

In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...

When this code sees a NaN in the output of an @jit function, it calls into the de-optimized code, so we still get a clear stack trace. We can then run a post-mortem debugger with %debug to inspect all the values and figure out the error.

⚠️ You shouldn’t have the NaN-checker on if you’re not debugging, as it can introduce lots of device-host round-trips and performance regressions!

⚠️ The NaN-checker doesn’t work with pmap. To debug NaNs in pmap code, one thing to try is replacing pmap with vmap.

Double (64bit) precision

At the moment, JAX by default enforces single-precision numbers to mitigate the NumPy API’s tendency to aggressively promote operands to double. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!

x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype
dtype('float32')

To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.

There are a few ways to do this:

  1. You can enable 64bit mode by setting the environment variable JAX_ENABLE_X64=True.

  2. You can manually set the jax_enable_x64 configuration flag at startup:

    # again, this only works on startup!
    from jax.config import config
    config.update("jax_enable_x64", True)
    
  3. You can parse command-line flags with absl.app.run(main)

    from jax.config import config
    config.config_with_absl()
    
  4. If you want JAX to run absl parsing for you, i.e. you don’t want to do absl.app.run(main), you can instead use

    from jax.config import config
    if __name__ == '__main__':
      # calls config.config_with_absl() *and* runs absl parsing
      config.parse_flags_with_absl()
    

Note that #2-#4 work for any of JAX’s configuration options.

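Once x64 mode has been enabled at startup, we can confirm it as follows (the cell output shown here still reports float32 because this session was started without the flag):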

import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
dtype('float32')
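
For comparison, here is a hedged sketch of a fresh session in which the flag is set before any arrays are created. It uses jax.config.update, which is an assumption about the installed JAX version; the import form shown in option 2 is the one documented above.

```python
import jax
jax.config.update("jax_enable_x64", True)   # set before creating any JAX arrays

import jax.numpy as jnp
from jax import random

x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
print(x.dtype)   # float64 once x64 mode is active
```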

Caveats

⚠️ XLA doesn’t support 64-bit convolutions on all backends!

Fin.

If something’s not covered here that has caused you weeping and gnashing of teeth, please let us know and we’ll extend these introductory advisos!