🔪 JAX - The Sharp Bits 🔪#
When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale”.
JAX is a language for expressing and composing transformations of numerical programs. JAX is also able to compile numerical programs for CPU or accelerators (GPU/TPU). JAX works great for many numerical and scientific programs, but only if they are written with certain constraints that we describe below.
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
🔪 Pure functions#
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.
Here are some examples of functions that are not functionally pure for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.
def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
g = 0.
def impure_uses_globals(x):
  return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
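One way to avoid the stale-global behavior shown above is to pass the value in as an ordinary argument, so that JAX traces it instead of baking it into the compiled code. The sketch below assumes a hypothetical function name, `pure_uses_argument`, and a hypothetical parameter name, `offset`; neither comes from the original example.

```python
import jax.numpy as jnp
from jax import jit

@jit
def pure_uses_argument(x, offset):
    # `offset` is an explicit input, so it is traced rather than captured as a constant.
    return x + offset

first = pure_uses_argument(4., 0.)
# Same argument types and shapes, so the cached compilation is reused,
# yet the new offset value is still respected.
second = pure_uses_argument(5., 10.)
print(first, second)
```

Because both calls use the same argument types and shapes, the second call can reuse the cached compilation and still see the updated value.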
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x
# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g) # Saved global has an internal JAX value
First call: 4.0
Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:
def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.))
50.0
Using iterators is not recommended in any JAX function you want to jit or in any control-flow primitive. An iterator is a Python object that carries state in order to retrieve the next element, so it is incompatible with JAX’s functional programming model. The code below shows some incorrect attempts to use iterators with JAX. Most of them raise an error, but some silently give unexpected results.
import jax.numpy as jnp
from jax import make_jaxpr
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
# lax.scan
def func11(arr, extra):
  ones = jnp.ones(arr.shape)
  def body(carry, aelems):
    ae1, ae2 = aelems
    return (carry + ae1 * ae2 + extra, carry)
  return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0
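A possible workaround, sketched here under the assumption that the iterator is finite and small enough to materialize, is to consume it once on the host and hand the resulting array to the control-flow primitive. The variable names `values` and `total` are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

iterator = iter(range(10))
# Consume the iterator once, on the host, before any traced code runs.
values = jnp.array(list(iterator))

# The loop body now indexes an array instead of calling next() on hidden state.
total = lax.fori_loop(0, values.shape[0], lambda i, acc: acc + values[i], 0)
print(total)
```

The loop body no longer touches any Python state, so tracing it once is enough to describe every iteration.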
🔪 In-place updates#
In NumPy you’re used to doing this:
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
If we try to update a JAX device array in-place, however, we get an error! (☉_☉)
%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.
Instead, JAX offers a functional array update using the .at property on JAX arrays.
⚠️ Inside jit’d code and lax.while_loop or lax.fori_loop, the size of slices can’t be functions of argument values but only functions of argument shapes; the slice start indices have no such restriction. See the Control flow section below for more information on this limitation.
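To illustrate the restriction on slice sizes, here is a minimal sketch using `lax.dynamic_slice`, which takes a start index that may be traced and a slice size that must be static. The helper name `window` is hypothetical, and marking `size` with `static_argnums` is just one way to keep it concrete.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, lax

@partial(jit, static_argnums=(2,))
def window(x, start, size):
    # `start` may be a traced value; `size` must be static because it fixes the output shape.
    return lax.dynamic_slice(x, (start,), (size,))

x = jnp.arange(10)
result = window(x, 3, 4)
print(result)
```

Calling `window` with a different `start` reuses the compiled code, whereas a different `size` triggers a recompile.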
Array updates: x.at[idx].set(y)#
For example, the update above can be written as:
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
JAX’s array update functions, unlike their NumPy versions, operate out-of-place. That is, the updated array is returned as a new array and the original array is not modified by the update.
print("original array unchanged:\n", jax_array)
original array unchanged:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
However, inside jit-compiled code, if the input value x of x.at[idx].set(y) is not reused, the compiler will optimize the array update to occur in-place.
Array updates with other operations#
Indexed array updates are not limited simply to overwriting values. For example, we can perform indexed addition as follows:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
For more details on indexed array updates, see the documentation for the .at property.
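As a further illustration, the sketch below exercises a few other `.at` methods (`add` with repeated indices, `multiply`, and `min`). The variable names are made up for this example.

```python
import jax.numpy as jnp

idx = jnp.array([0, 0, 2])

# Repeated indices accumulate: position 0 receives two additions.
added = jnp.zeros(5).at[idx].add(1.)

# Multiply a slice in an out-of-place fashion.
scaled = jnp.ones(5).at[1:3].multiply(4.)

# Elementwise minimum of the selected entries with 2.
capped = jnp.arange(5.).at[:].min(2.)

print(added, scaled, capped)
```

Note that the accumulation for repeated indices differs from NumPy’s `x[idx] += 1`, which applies the update only once per distinct index.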
🔪 Out-of-bounds indexing#
In NumPy, you are used to errors being thrown when you index an array outside of its bounds, like this:
np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10
However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out-of-bounds indexing (akin to how invalid floating-point arithmetic results in NaN). When the indexing operation is an array index update (e.g. index_add or scatter-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or gather-like primitives), the index is clamped to the bounds of the array since something must be returned. For example, the last value of the array will be returned from this indexing operation:
jnp.arange(10)[11]
Array(9, dtype=int32)
If you would like finer-grained control over the behavior for out-of-bounds indices, you can use the optional parameters of ndarray.at; for example:
jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)
Note that due to this behavior for index retrieval, functions like jnp.nanargmin and jnp.nanargmax return -1 for slices consisting of NaNs, whereas NumPy would throw an error.
Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) will not preserve the semantics of out of bounds indexing. Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.
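The asymmetry between retrieval and update described above can be seen side by side in a small sketch. The fill value of -1.0 and the variable names are arbitrary choices for illustration.

```python
import jax.numpy as jnp

x = jnp.arange(10.0)

# Retrieval: the out-of-bounds index is clamped to the last element.
clamped = x[11]

# Retrieval with an explicit fill value instead of clamping.
filled = x.at[11].get(mode='fill', fill_value=-1.0)

# Update: the out-of-bounds write is skipped, so the result equals x.
unchanged = x.at[11].set(100.0)

print(clamped, filled, unchanged)
```

If code depends on one of these behaviors, stating the mode explicitly makes the intent visible to readers.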
🔪 Non-array inputs: NumPy vs. JAX#
NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:
np.sum([1, 2, 3])
np.int64(6)
JAX departs from this, generally returning a helpful error:
jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.
For example, consider the following permissive version of jnp.sum that allows list inputs:
def permissive_sum(x):
  return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)
The output is what we would expect, but this hides potential performance issues under the hood. In JAX’s tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the permissive_sum function above:
make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
j:i32[]. let
k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m
x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o
z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p
ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r
bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t
be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
bf:i32[] = reduce_sum[axes=(0,)] be
in (bf,) }
Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.
If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:
jnp.sum(jnp.array(x))
Array(45, dtype=int32)
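To make the difference concrete, the following sketch counts the inputs of the traced computation in both cases. It assumes the `ClosedJaxpr` object returned by `make_jaxpr` exposes its input variables as `.jaxpr.invars`; the variable names are illustrative.

```python
import jax.numpy as jnp
from jax import make_jaxpr

x = list(range(10))

def permissive_sum(values):
    return jnp.sum(jnp.array(values))

# Passing the list directly: every element becomes its own traced input.
n_inputs_list = len(make_jaxpr(permissive_sum)(x).jaxpr.invars)

# Converting up front: the whole array is a single traced input.
n_inputs_array = len(make_jaxpr(jnp.sum)(jnp.array(x)).jaxpr.invars)

print(n_inputs_list, n_inputs_array)
```

The number of traced inputs grows with the list length in the first case and stays at one in the second.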
🔪 Random numbers#
If all scientific papers whose results are in doubt because of bad rand()s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist. - Numerical Recipes
RNGs and state#
You’re used to stateful pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.01098694126301536
0.5379060587704946
0.3318452119910925
Underneath the hood, numpy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of \(2^{19937}-1\) and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this “entropy” has been used up.
np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, “consuming” 2 of the uint32s in the Mersenne twister state vector:
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
The problem with magic PRNG state is that it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end user.
The Mersenne Twister PRNG is also known to have a number of problems: it has a large 2.5kB state size, which leads to problematic initialization issues, it fails modern BigCrush tests, and it is generally slow.
JAX PRNG#
JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by a special array element that we call a key:
key = random.key(0)
key
Array((), dtype=key<fry>) overlaying:
[0 0]
JAX’s random functions produce pseudorandom numbers from the PRNG state, but do not change the state!
Reusing the same state will cause sadness and monotony, depriving the end user of lifegiving chaos:
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[0 0]
\---SPLIT --> new key Array((), dtype=key<fry>) overlaying:
[4146024105 967050713]
\--> new subkey Array((), dtype=key<fry>) overlaying:
[2718843009 1272950319] --> normal [-1.2515389]
We propagate the key and make new subkeys whenever we need a new random number:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[4146024105 967050713]
\---SPLIT --> new key Array((), dtype=key<fry>) overlaying:
[2384771982 3928867769]
\--> new subkey Array((), dtype=key<fry>) overlaying:
[1278412471 2182328957] --> normal [-0.58665055]
We can generate more than one subkey at a time:
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
[-0.37533438]
[0.98645043]
[0.14553197]
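The split-then-consume pattern can be wrapped in a function that takes a key and returns the next key along with its samples. This is only a sketch; the helper name `sample_pair` and the shapes are assumptions made for illustration.

```python
import jax.numpy as jnp
from jax import random

def sample_pair(key):
    # Split first, then give each consumer its own subkey; `key` itself is never reused.
    key, k1, k2 = random.split(key, 3)
    a = random.normal(k1, shape=(3,))
    b = random.normal(k2, shape=(3,))
    return key, a, b

key = random.key(0)
key, a, b = sample_pair(key)
key, c, d = sample_pair(key)  # a fresh key yields fresh samples

# Starting again from the same seed reproduces the same draws.
_, a_again, _ = sample_pair(random.key(0))
```

Threading the returned key through successive calls keeps every draw independent while leaving the whole sequence reproducible from the initial seed.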
🔪 Control flow#
✔ Python control_flow + autodiff ✔#
If you just want to apply grad to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using Autograd (or PyTorch or TF Eager).
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
12.0
-4.0
Python control flow + JIT#
Using control flow with jit is more complicated, and by default it has more constraints.
This works:
@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x
print(f(3))
24
So does this:
@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y
print(g(jnp.array([1., 2., 3.])))
6.0
But this doesn’t, at least by default:
@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
# This will fail!
f(2)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1265/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
What gives!?
When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don’t have to re-compile on each function evaluation.
For example, if we evaluate an @jit function on the array jnp.array([1., 2., 3.], jnp.float32), we might want to compile code that we can reuse to evaluate the function on jnp.array([4., 5., 6.], jnp.float32) to save on compile time.
To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.
By default, jit traces your code on the ShapedArray abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), jnp.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
But there’s a tradeoff here: if we trace a Python function on a ShapedArray((), jnp.float32) that isn’t committed to a specific concrete value, when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), jnp.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don’t know which branch to take, and can’t continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
The good news is that you can control this tradeoff yourself. By having jit trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnums argument to jit, we can specify to trace on concrete values of some arguments. Here’s that example function again:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
12.0
Here’s another example, this time involving a loop:
def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y
f = jit(f, static_argnums=(1,))
f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)
In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that’s not currently the default for any transformation.
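Static arguments can also be selected by name rather than by position through the `static_argnames` parameter of `jit`. The sketch below reuses the loop example under a hypothetical function name, `sum_first`.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

@partial(jit, static_argnames=("n",))
def sum_first(x, n):
    y = 0.
    for i in range(n):  # unrolled at trace time because n is static
        y = y + x[i]
    return y

result = sum_first(jnp.array([2., 3., 4.]), n=2)
print(result)
```

Naming the static argument tends to be more robust than a positional index when the function signature changes later.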
️⚠️ functions with argument-value dependent shapes
These control-flow issues also come up in a more subtle way: numerical functions we want to jit can’t specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let’s make a function whose output happens to depend on the input variable length.
def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_1265/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
static_argnums can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot!
Lastly, if your function has global side-effects, JAX’s tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit’d functions:
@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Array(4, dtype=int32, weak_type=True)
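If the goal is to see runtime values rather than tracers, one option is `jax.debug.print`, which is designed to work inside jit-compiled functions. The following is a minimal sketch; the format strings are arbitrary.

```python
import jax
from jax import jit

@jit
def f(x):
    # Unlike Python's print, this reports the runtime value, not the tracer.
    jax.debug.print("x = {}", x)
    y = 2 * x
    jax.debug.print("y = {}", y)
    return y

result = f(2)
```

Unlike a bare `print`, which only runs while the function is being traced, this call is staged into the compiled computation and runs on every invocation.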
Structured control flow primitives#
There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that’s traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:
lax.cond: differentiable
lax.while_loop: fwd-mode-differentiable
lax.fori_loop: fwd-mode-differentiable in general; fwd and rev-mode differentiable if endpoints are static.
lax.scan: differentiable
cond#
Python equivalent:
def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
from jax import lax
operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)
jax.lax provides two other functions that allow branching on dynamic predicates:
lax.select is like a batched version of lax.cond, with the choices expressed as pre-computed arrays rather than as functions.
lax.switch is like lax.cond, but allows switching between any number of callable choices.
In addition, jax.numpy provides several numpy-style interfaces to these functions:
jnp.where with three arguments is the numpy-style wrapper of lax.select.
jnp.piecewise is a numpy-style wrapper of lax.switch, but switches on a list of boolean conditions rather than a single scalar index.
jnp.select has an API similar to jnp.piecewise, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to lax.select.
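As a sketch of how these primitives apply to the branching function that failed under jit earlier, here are two possible rewrites, one with `lax.cond` and one with `jnp.where`. The names `f_cond` and `f_where` are hypothetical, and the input is cast to float32 so that both branches return the same dtype.

```python
import jax.numpy as jnp
from jax import grad, jit, lax

@jit
def f_cond(x):
    x = jnp.asarray(x, dtype=jnp.float32)
    # Only the selected branch is executed at runtime.
    return lax.cond(x < 3, lambda v: 3. * v ** 2, lambda v: -4. * v, x)

@jit
def f_where(x):
    x = jnp.asarray(x, dtype=jnp.float32)
    # Both expressions are computed; the predicate selects elementwise.
    return jnp.where(x < 3, 3. * x ** 2, -4. * x)

print(f_cond(2.), f_cond(4.))
print(grad(f_cond)(2.), grad(f_cond)(4.))
print(f_where(jnp.array([2., 4.])))
```

The `jnp.where` form also works elementwise on arrays, at the cost of evaluating both expressions for every element.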
while_loop#
Python equivalent:
def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)
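The loop state does not have to be a single scalar. The sketch below, with a hypothetical function name `count_halvings`, carries a tuple through `lax.while_loop` inside a jit-compiled function, with a stopping condition that depends on a runtime value.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def count_halvings(x):
    def cond_fun(state):
        value, _ = state
        return value >= 1.0

    def body_fun(state):
        value, steps = state
        # The carried state must keep the same structure, shapes and dtypes.
        return value / 2.0, steps + 1

    _, steps = lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return steps

result = count_halvings(8.0)
print(result)
```

The number of iterations is decided at runtime, so the same compiled function serves any scalar input of the same dtype.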
fori_loop#
Python equivalent:
def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
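For completeness, here is a comparable sketch for `lax.scan`, the fourth primitive listed earlier. It computes a running sum, returning both the final carry and the stacked per-step outputs; the function name `running_sum` is made up for this example.

```python
import jax.numpy as jnp
from jax import lax

def running_sum(xs):
    def step(carry, x):
        new_carry = carry + x
        # Return the updated carry and the per-step output to be stacked.
        return new_carry, new_carry

    init = jnp.zeros((), dtype=xs.dtype)
    return lax.scan(step, init, xs)

total, partial_sums = running_sum(jnp.array([1., 2., 3., 4.]))
print(total, partial_sums)
```

The scan iterates over the leading axis of `xs`, so the loop length is fixed by the input shape rather than by any runtime value.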
Summary#
\(\ast\) = argument-value-independent loop condition - unrolls the loop
🔪 Dynamic shapes#
JAX code used within transforms like jax.jit, jax.vmap, jax.grad, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.
For example, if you were implementing your own version of jnp.nansum, you might start with something like this:
def nansum(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  x_without_nans = x[mask]
  return x_without_nans.sum()
Outside JIT and other transforms, this works as expected:
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0
If you attempt to apply jax.jit or another transform to this function, it will error:
jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
The problem is that the size of x_without_nans is dependent on the values within x, which is another way of saying its size is dynamic.
Often in JAX it is possible to work around the need for dynamically-sized arrays via other means.
For example, here it is possible to use the three-argument form of jnp.where to replace the NaN values with zeros, thus computing the same result while avoiding dynamic shapes:
@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  return jnp.where(mask, x, 0).sum()
print(nansum_2(x))
10.0
Similar tricks can be played in other situations where dynamically-shaped arrays occur.
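As one more instance of the same masking trick, a NaN-aware mean can be sketched by summing the mask to count the valid entries instead of selecting them. The function name `nanmean_masked` is hypothetical; JAX already provides `jnp.nanmean` for real use.

```python
import jax
import jax.numpy as jnp

@jax.jit
def nanmean_masked(x):
    mask = ~jnp.isnan(x)                 # True where the value is valid
    total = jnp.where(mask, x, 0).sum()  # NaNs contribute zero to the sum
    count = mask.sum()                   # number of valid entries, as a traced scalar
    return total / count

x = jnp.array([1, 2, jnp.nan, 3, 4])
result = nanmean_masked(x)
print(result)
```

Every intermediate array keeps the shape of `x`, so nothing in the function depends on how many NaNs the input contains.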
🔪 NaNs#
Debugging NaNs#
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
setting the JAX_DEBUG_NANS=True environment variable;
adding jax.config.update("jax_debug_nans", True) near the top of your main file;
adding jax.config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.
This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an @jit. For code under an @jit, the output of every @jit function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of @jit at a time.
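A minimal sketch of the configuration route is shown below. It enables the checker programmatically, triggers a NaN outside of jit, and catches the resulting exception; the try/finally wrapper and the `caught` flag are additions for illustration so that the option is switched off again afterwards.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)
try:
    jnp.divide(jnp.float32(0.), jnp.float32(0.))  # 0/0 produces a NaN
    caught = False
except FloatingPointError as err:
    caught = True
    print("caught:", err)
finally:
    # Turn the checker back off; it is meant for debugging sessions only.
    jax.config.update("jax_debug_nans", False)
```

In an interactive session you would normally leave the exception uncaught so that a post-mortem debugger can inspect the failing frame.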
There could be tricky situations that arise, like nans that only occur under a @jit but don’t get produced in de-optimized mode. In that case you’ll see a warning message print out but your code will continue to execute.
If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line env JAX_DEBUG_NANS=True ipython, then ran this:
In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
103 py_val = device_buffer.to_py()
104 if np.any(np.isnan(py_val)):
--> 105 raise FloatingPointError("invalid value")
106 else:
107 return Array(device_buffer, *result_shape)
FloatingPointError: invalid value
The nan generated was caught. By running %debug, we can get a post-mortem debugger. This also works with functions under @jit, as the example below shows.
In [4]: from jax import jit
In [5]: @jit
...: def f(x, y):
...: a = x * y
...: b = (x + y) / (x - y)
...: c = a + 2
...: return a + b * c
...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
2 def f(x, y):
3 a = x * y
----> 4 b = (x + y) / (x - y)
5 c = a + 2
6 return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
When this code sees a nan in the output of an @jit function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with %debug to inspect all the values to figure out the error.
⚠️ You shouldn’t have the NaN-checker on if you’re not debugging, as it can introduce lots of device-host round-trips and performance regressions!
⚠️ The NaN-checker doesn’t work with pmap. To debug nans in pmap code, one thing to try is replacing pmap with vmap.
🔪 Double (64bit) precision#
At the moment, JAX by default enforces single-precision numbers to mitigate the NumPy API’s tendency to aggressively promote operands to double. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/tmp/ipykernel_1265/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')
To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.
There are a few ways to do this:
1. You can enable 64-bit mode by setting the environment variable JAX_ENABLE_X64=True.
2. You can manually set the jax_enable_x64 configuration flag at startup:
   # again, this only works on startup!
   import jax
   jax.config.update("jax_enable_x64", True)
3. You can parse command-line flags with absl.app.run(main):
   import jax
   jax.config.config_with_absl()
4. If you want JAX to run absl parsing for you, i.e. you don’t want to do absl.app.run(main), you can instead use:
   import jax
   if __name__ == '__main__':
     # calls jax.config.config_with_absl() *and* runs absl parsing
     jax.config.parse_flags_with_absl()
Note that #2-#4 work for any of JAX’s configuration options.
We can then confirm that x64 mode is enabled, for example:
import jax
import jax.numpy as jnp
from jax import random
jax.config.update("jax_enable_x64", True)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
Caveats#
⚠️ XLA doesn’t support 64-bit convolutions on all backends!
🔪 Miscellaneous divergences from NumPy#
While jax.numpy makes every attempt to replicate the behavior of numpy’s API, there do exist corner cases where the behaviors differ.
Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.
For binary operations, JAX’s type promotion rules differ somewhat from those used by NumPy. See Type Promotion Semantics for more details.
When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX’s behavior may be backend dependent, and in general may diverge from NumPy’s behavior. NumPy allows control over the result in these scenarios via the casting argument (see np.ndarray.astype); JAX does not provide any such configuration, instead directly inheriting the behavior of XLA:ConvertElementType.
Here is an example of an unsafe cast with differing results between NumPy and JAX:
>>> np.arange(254.0, 258.0).astype('uint8')
array([254, 255, 0, 1], dtype=uint8)
>>> jnp.arange(254.0, 258.0).astype('uint8')
Array([254, 255, 255, 255], dtype=uint8)
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
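If a predictable result matters more than matching any one backend, one option is to make the cast safe by clipping to the target dtype’s range first. This is only a sketch of that idea, using `jnp.iinfo` to look up the bounds; the variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(254.0, 258.0)  # 254., 255., 256., 257.

# Clamp to the representable range so the cast itself is no longer unsafe.
info = jnp.iinfo(jnp.uint8)
safe = jnp.clip(x, info.min, info.max).astype(jnp.uint8)
print(safe)
```

After clipping, every value is representable in the target dtype, so the result no longer relies on backend-specific conversion behavior.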
Fin.#
If something’s not covered here that has caused you weeping and gnashing of teeth, please let us know and we’ll extend these introductory advisos!