JAX Frequently Asked Questions (FAQ)

We are collecting here answers to frequently asked questions. Contributions welcome!

jit changes the behavior of my function

If you have a Python function that changes behavior after using jax.jit(), perhaps your function uses global state, or has side-effects. In the following code, the impure_func uses the global y and has a side-effect due to print:

from jax import jit

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

Without jit the output is:

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

and with jit it is:

Inside: 0
Result: 0
Result: 1
Result: 2

For jax.jit(), the function is executed once using the Python interpreter, at which time the Inside printing happens, and the first value of y is observed. Then, the function is compiled and cached, and executed multiple times with different values of x, but with the same first value of y.
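One way to make the jitted and non-jitted behavior agree is to pass y explicitly as an argument. It is then traced like x instead of being read from the global scope once, at trace time. This is a minimal sketch that assumes you are free to change the function's signature; the name pure_func is hypothetical.

```python
import jax

@jax.jit
def pure_func(x, y):
  # y is an explicit argument, so every call uses the value actually passed in
  # rather than the global value captured when the function was first traced.
  return x + y

results = [int(pure_func(y, y)) for y in range(3)]
print("Results:", results)  # Results: [0, 2, 4]
```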

jit decorated function is very slow to compile

If your jit decorated function takes tens of seconds (or more!) to run the first time you call it, but executes quickly when called again, JAX is taking a long time to trace or compile your code.

This is usually a symptom that calling your function generates a large amount of code in JAX’s internal representation, typically because it makes heavy use of Python control flow such as for loops. For a handful of loop iterations Python is OK, but if you need many loop iterations, you should rewrite your code to make use of JAX’s structured control flow primitives (such as lax.scan()) or avoid wrapping the loop with jit (you can still use jit decorated functions inside the loop).
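The sketch below shows both fixes for a hypothetical update rule, step(x) = sin(x) + 0.5 x, repeated N_STEPS times. Option 1 uses lax.fori_loop, a structured loop primitive; it is swapped in for lax.scan() because no per-step outputs are needed, and it traces the body only once. Option 2 keeps the Python loop but jits only the body. The names step, run_fori and run_python_loop are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 100  # hypothetical iteration count; unrolled tracing cost grows with it

def step(x):
  # Hypothetical loop body.
  return jnp.sin(x) + 0.5 * x

# Option 1: a structured control flow primitive inside jit.
# The body is traced once, however large N_STEPS is.
@jax.jit
def run_fori(x):
  return lax.fori_loop(0, N_STEPS, lambda i, carry: step(carry), x)

# Option 2: keep the Python loop outside jit and compile only the body.
step_jit = jax.jit(step)

def run_python_loop(x):
  for _ in range(N_STEPS):
    x = step_jit(x)
  return x

x0 = jnp.array(1.0)
print(run_fori(x0), run_python_loop(x0))  # should agree up to rounding
```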

If you’re not sure if this is the problem, you can try running jax.make_jaxpr() on your function. You can expect slow compilation if the output is many hundreds or thousands of lines long.
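To illustrate what jax.make_jaxpr() reveals, this sketch compares a hypothetical 50-iteration Python loop with the same computation written using lax.scan(). The unrolled jaxpr grows with the iteration count, while the scan version records the body only once.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 50  # hypothetical iteration count

def unrolled(x):
  # Python loop: every iteration is recorded separately in the jaxpr.
  for _ in range(N_STEPS):
    x = jnp.sin(x) + 0.5 * x
  return x

def with_scan(x):
  # lax.scan: the body appears once, nested inside the scan primitive.
  def body(carry, _):
    return jnp.sin(carry) + 0.5 * carry, None
  final, _ = lax.scan(body, x, xs=None, length=N_STEPS)
  return final

x0 = jnp.array(1.0)
unrolled_lines = len(str(jax.make_jaxpr(unrolled)(x0)).splitlines())
scan_lines = len(str(jax.make_jaxpr(with_scan)(x0)).splitlines())
print(unrolled_lines, scan_lines)
```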

Sometimes it isn’t obvious how to rewrite your code to avoid Python loops because your code makes use of many arrays with different shapes. The recommended solution in this case is to make use of functions like jax.numpy.where() to do your computation on padded arrays with fixed shape. The JAX team is exploring a “masking” transformation to make such code easier to write.
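As a sketch of the padding approach, assume a hypothetical batch of variable-length sequences whose per-sequence means we want. The sequences are padded to a fixed length with an arbitrary filler value (99.0 here), and jax.numpy.where() masks out the filler. The function masked_mean is illustrative, not a JAX API.

```python
import jax
import jax.numpy as jnp

def masked_mean(padded, lengths):
  # padded: (batch, max_len), filler values beyond each row's length
  # lengths: (batch,) number of valid entries in each row
  positions = jnp.arange(padded.shape[1])
  mask = positions[None, :] < lengths[:, None]
  total = jnp.sum(jnp.where(mask, padded, 0.0), axis=1)
  return total / lengths

seqs = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
max_len = max(len(s) for s in seqs)
padded = jnp.array([s + [99.0] * (max_len - len(s)) for s in seqs])
lengths = jnp.array([len(s) for s in seqs])

# Every batch has the same fixed shape, so one compiled version is reused.
result = jax.jit(masked_mean)(padded, lengths)
print(result)  # [2.  4.  5.5]
```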

If your functions are slow to compile for another reason, please open an issue on GitHub.

Controlling data and computation placement on devices

Let’s first look at the principles of data and computation placement in JAX.

In JAX, the computation follows data placement. JAX arrays have two placement properties: 1) the device where the data resides; and 2) whether it is committed to the device or not (the data is sometimes referred to as being sticky to the device).

By default, JAX arrays are placed uncommitted on the default device (jax.devices()[0]), which is the first GPU by default. If no GPU is present, jax.devices()[0] is the first CPU. The default device can be set to “cpu” or “gpu” manually by setting the environment variable JAX_PLATFORM_NAME or the absl flag --jax_platform_name.

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device_buffer.device())  
gpu:0

Computations involving uncommitted data are performed on the default device and the results are uncommitted on the default device.
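The following short sketch checks this default behavior. It assumes a recent JAX release in which arrays expose a devices() method; older releases used device_buffer, as in the example above. A freshly created array and a result computed from it both live on jax.devices()[0].

```python
import jax
import jax.numpy as jnp

default_device = jax.devices()[0]
x = jnp.ones(3)  # placed uncommitted on the default device
y = x + 1        # uncommitted inputs: computed on the default device
print(x.devices(), y.devices())
```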

Data can also be placed explicitly on a device using jax.device_put() with a device parameter, in which case the data becomes committed to the device:

>>> import jax
>>> from jax import device_put
>>> print(device_put(1, jax.devices()[2]).device_buffer.device())  
gpu:2

Computations involving some committed inputs will happen on the committed device and the result will be committed on the same device. Invoking an operation on arguments that are committed to more than one device will raise an error.

You can also use jax.device_put() without a device parameter. If the data is already on a device (committed or not), it’s left as-is. If the data isn’t on any device—that is, it’s a regular Python or NumPy value—it’s placed uncommitted on the default device.

Jitted functions behave like any other primitive operations: they follow the data and raise an error if invoked on data committed to more than one device.
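Here is a sketch of a jitted computation following committed data. It again assumes the devices() method of recent releases. The target device jax.devices()[-1] is an arbitrary choice that also works on a machine with a single device, and the lambda is a hypothetical stand-in for your own function.

```python
import jax
import jax.numpy as jnp

target = jax.devices()[-1]                    # arbitrary device choice
x = jax.device_put(jnp.arange(3.0), target)   # x is now committed to `target`
f = jax.jit(lambda a: a * 2 + 1)
y = f(x)                                      # the computation follows x
print(y.devices())
```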

jax.device_put(jnp.zeros(...), jax.devices()[1]) or similar will actually create the array of zeros on jax.devices()[1], instead of creating the array on the default device and then moving it. This is thanks to some laziness in array creation, which holds for all the constant creation operations (ones, full, eye, etc.).

(As of April 2020, jax.jit() has a device parameter that affects the device placement. That parameter is experimental, is likely to be removed or changed, and its use is not recommended.)

For a worked-out example, we recommend reading through test_computation_follows_data in multi_device_test.py.

Abstract tracer value encountered where concrete value is expected error

If you are getting an error that a library function is called with “Abstract tracer value encountered where concrete value is expected”, you may need to change how you invoke JAX transformations. Below is an example and a couple of possible solutions, followed by the details of what is actually happening, if you are curious or the simple solution does not work for you.

Some library functions take arguments that specify shapes or axes, such as the second and third arguments for jax.numpy.split():

import jax
import jax.numpy as np  # in this FAQ, np refers to jax.numpy

# def np.split(arr, num_sections: Union[int, Sequence[int]], axis: int):
np.split(np.zeros(2), 2, 0)  # works

If you try the following code:

jax.jit(np.split)(np.zeros(4), 2, 0)

you will get the following error:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in jax.numpy.split argument 1).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(int32[], weak_type=True):JaxprTrace(level=-1/1)>

You must change the way you use jax.jit() to ensure that the num_sections and axis arguments use their concrete values (2 and 0 respectively). The best mechanism is to use special transformation parameters to declare some arguments to be static, e.g., static_argnums for jax.jit():

jax.jit(np.split, static_argnums=(1, 2))(np.zeros(4), 2, 0)

An alternative is to apply the transformation to a closure that encapsulates the arguments to be protected, either manually as below or by using functools.partial:

jax.jit(lambda arr: np.split(arr, 2, 0))(np.zeros(4))

Note that a new closure, and therefore a new jitted function, is created every time this expression is evaluated, which defeats the compilation caching mechanism; this is why static_argnums is preferred.
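To benefit from the cache, build the jitted function once and reuse it. This sketch shows that pattern with static_argnums, and the same idea written as a decorator via functools.partial. The names split_jit and split_in are hypothetical.

```python
import functools
import jax
import jax.numpy as jnp

# Created once and reused: later calls with the same static values
# (and the same input shapes) can hit the compilation cache.
split_jit = jax.jit(jnp.split, static_argnums=(1, 2))
halves = split_jit(jnp.arange(4.0), 2, 0)

# The same idea as a decorator on your own function.
@functools.partial(jax.jit, static_argnums=(1,))
def split_in(arr, num_sections):
  return jnp.split(arr, num_sections)

thirds = split_in(jnp.arange(6.0), 3)
print(halves, thirds)
```

Static arguments must be hashable, and each distinct static value leads to a separate compilation.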

To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read Different kinds of JAX values.

Different kinds of JAX values

In the process of transforming functions, JAX replaces some function arguments with special tracer values.

You could see this if you use a print statement:

import jax
import jax.numpy as np

def func(x):
  print(x)
  return np.cos(x)

res = jax.jit(func)(0.)

The above code does return the correct value 1.0, but it also prints a tracer such as Traced<ShapedArray(float32[])> for the value of x. Normally, JAX handles these tracer values internally in a transparent way, e.g., in the numeric JAX primitives that are used to implement the jax.numpy functions. This is why np.cos works in the example above.

More precisely, a tracer value is introduced for each argument of a JAX-transformed function, except the arguments identified by special parameters such as static_argnums for jax.jit() or static_broadcasted_argnums for jax.pmap(). Typically, computations that involve at least one tracer value will produce a tracer value. Besides tracer values, there are regular Python values: values that are computed outside JAX transformations, that arise from the above-mentioned static arguments of certain JAX transformations, or that are computed solely from other regular Python values. These are the values that are used everywhere in the absence of JAX transformations.
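This sketch records what a jitted function receives while it is being traced. The non-static argument x arrives as a tracer, while the static argument n stays a plain Python int. The function scale and the dictionary seen are illustrative, and the exact tracer class name varies between JAX versions.

```python
import jax
import jax.numpy as jnp

seen = {}

def scale(x, n):
  # Runs during tracing: record the Python types of the arguments.
  seen["x"] = type(x).__name__
  seen["n"] = type(n).__name__
  return x * n

out = jax.jit(scale, static_argnums=1)(jnp.ones(2), 3)
print(seen)  # x is some *Tracer class; n is int
```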

A tracer value carries an abstract value, e.g., a ShapedArray with information about the shape and dtype of an array. We will refer here to such tracers as abstract tracers. Some tracers, e.g., those that are introduced for arguments of autodiff transformations, carry ConcreteArray abstract values that actually include the regular array data, and are used, e.g., for resolving conditionals. We will refer here to such tracers as concrete tracers. Values computed from these concrete tracers, perhaps in combination with regular values, are also concrete tracers. A concrete value is either a regular value or a concrete tracer.

Most often, values computed from tracer values are themselves tracer values. There are very few exceptions: when a computation can be done entirely using the abstract value carried by a tracer, the result can be a regular value. For example, the shape of a tracer with a ShapedArray abstract value is a regular value. Another example is explicitly casting a concrete tracer value to a regular type, e.g., int(x) or x.astype(float). Another such situation is bool(x), which produces a Python bool when concreteness makes it possible. That case is especially salient because of how often it arises in control flow.
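Here is a minimal sketch of the shape exception. Inside a jitted function, x.shape is computed from the abstract value alone, so it is a regular Python tuple that ordinary Python control flow can use. The function scale_by_length is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale_by_length(x):
  n = x.shape[0]  # a plain Python int, even though x is an abstract tracer
  if n > 2:       # so Python control flow on n is allowed
    return x * n
  return x

print(scale_by_length(jnp.ones(3)), scale_by_length(jnp.ones(2)))
```

Each new input shape triggers a new trace, so the branch is resolved separately per shape.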

Here is how the transformations introduce abstract or concrete tracers:

  • jax.jit(): introduces abstract tracers for all positional arguments except those denoted by static_argnums, which remain regular values.

  • jax.pmap(): introduces abstract tracers for all positional arguments except those denoted by static_broadcasted_argnums.

  • jax.vmap(), jax.make_jaxpr(), xla_computation(): introduce abstract tracers for all positional arguments.

  • jax.jvp() and jax.grad() introduce concrete tracers for all positional arguments. An exception is when these transformations are within an outer transformation and the actual arguments are themselves abstract tracers; in that case, the tracers introduced by the autodiff transformations are also abstract tracers.

  • All higher-order control-flow primitives (lax.cond(), lax.while_loop(), lax.fori_loop(), lax.scan()) introduce abstract tracers when they process the functions passed to them, whether or not there is a JAX transformation in progress.
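The last point can be surprising, so here is a sketch. Even with no jit involved, the loop index passed to a lax.fori_loop body is an abstract tracer, so a Python if on it fails. The sketch assumes a JAX release that provides jax.errors.ConcretizationTypeError. Moving the selection into jnp.where keeps it inside the traced computation. The bodies bad_body and good_body are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def bad_body(i, total):
  # i is an abstract tracer here, so Python's `if` cannot turn
  # `i > 0` into a bool.
  if i > 0:
    return total + i
  return total

try:
  lax.fori_loop(0, 3, bad_body, jnp.int32(0))
  raised = False
except jax.errors.ConcretizationTypeError:
  raised = True

def good_body(i, total):
  # jnp.where makes the selection part of the traced computation.
  return total + jnp.where(i > 0, i, 0)

fixed = lax.fori_loop(0, 3, good_body, jnp.int32(0))  # 0 + 0 + 1 + 2
print(raised, fixed)
```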

All of this is relevant when you have code that can operate only on regular Python values, such as code that has conditional control-flow based on data:

def divide(x, y):
  return x / y if y >= 1. else 0.

If we want to apply jax.jit(), we must specify static_argnums=1 so that y stays a regular value. This is because the boolean expression y >= 1. requires a concrete value (a regular value or a concrete tracer). The same would happen if we explicitly wrote bool(y >= 1.), int(y), or float(y).
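The sketch below shows both outcomes with the divide function above. It assumes a JAX release where the failure surfaces as jax.errors.ConcretizationTypeError or a subclass of it. The name divide_jit is hypothetical.

```python
import jax

def divide(x, y):
  return x / y if y >= 1. else 0.

# Without static_argnums, y is an abstract tracer and `y >= 1.` cannot become a bool.
try:
  jax.jit(divide)(3., 2.)
  failed = False
except jax.errors.ConcretizationTypeError:
  failed = True

# Marking y static keeps it a regular Python float during tracing.
divide_jit = jax.jit(divide, static_argnums=1)
a = divide_jit(3., 2.)   # takes the x / y branch
b = divide_jit(3., 0.5)  # takes the else branch
print(failed, a, b)
```

Because y is static, each distinct value of y gets its own compiled version.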

Interestingly, jax.grad(divide)(3., 2.) works because jax.grad() uses concrete tracers and resolves the conditional using the concrete value of y.
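This is a short sketch of both gradients. Differentiating with respect to y (argnums=1) turns y itself into a tracer, yet the Python conditional still resolves outside of jit. The expected values follow from d/dx (x / y) = 1 / y and d/dy (x / y) = -x / y².

```python
import jax

def divide(x, y):
  return x / y if y >= 1. else 0.

dx = jax.grad(divide)(3., 2.)             # 1 / y = 0.5
dy = jax.grad(divide, argnums=1)(3., 2.)  # -x / y**2 = -0.75
print(dx, dy)
```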

Gradients contain NaN where using where

If you define a function using where to avoid an undefined value, you may obtain a NaN for reverse differentiation if you are not careful:

import jax
import jax.numpy as np

def my_log(x):
  return np.where(x > 0., np.log(x), 0.)

my_log(0.)            # ==> 0.  Ok
jax.grad(my_log)(0.)  # ==> NaN

A short explanation is that during grad computation, the adjoint corresponding to the undefined np.log(x) is a NaN. np.where sends a zero cotangent to the unselected branch, but the backward pass of np.log divides that cotangent by x, giving 0 / 0 = NaN at x = 0, and this NaN gets accumulated into the adjoint of x. The correct way to write such functions is to ensure that there is a np.where inside the partially-defined function, so that the adjoint is always finite:

def safe_for_grad_log(x):
  return np.log(np.where(x > 0., x, 1.))

safe_for_grad_log(0.)            # ==> 0.  Ok
jax.grad(safe_for_grad_log)(0.)  # ==> 0.  Ok

The inner np.where may be needed in addition to the original one, e.g.:

def my_log_or_y(x, y):
  """Return log(x) if x > 0, else y."""
  return np.where(x > 0., np.log(np.where(x > 0., x, 1.)), y)
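To check the difference numerically, this sketch evaluates the gradient of the naive version and of the double-where version at a few points, including x = 0, using jax.vmap over the inputs. Here y is held fixed at an arbitrary 5.0.

```python
import jax
import jax.numpy as jnp

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

def my_log_or_y(x, y):
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

xs = jnp.array([0., 1., 2.])
naive = jax.vmap(jax.grad(my_log))(xs)
safe = jax.vmap(jax.grad(my_log_or_y), in_axes=(0, None))(xs, 5.)
print(naive)  # first entry is nan
print(safe)   # [0.  1.  0.5]
```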
