Omnistaging#

mattjj@ Sept 25 2020

This is more of an upgrade guide than a design doc.

tl;dr#

What’s going on?#

A change to JAX’s tracing infrastructure called “omnistaging” (google/jax#3370) was switched on in jax==0.2.0. This change improves memory performance, trace execution time, and simplifies jax internals, but may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it’s best to fix the bugs, but omnistaging can also be disabled as a temporary workaround. And we’re happy to help you with fixes!

How do I know if omnistaging broke my code?#

The easiest way to tell if omnistaging is responsible is to disable omnistaging and see if the issues go away. See the What issues can arise when omnistaging is switched on? section below.

How can I disable omnistaging for now?#

Note: this applies to JAX versions 0.2.0 through 0.2.11; omnistaging cannot be disabled in JAX versions 0.2.12 and higher

It is temporarily possible to disable omnistaging by

  1. setting the shell environment variable JAX_OMNISTAGING to something falsey;

  2. setting the boolean flag jax_omnistaging to soething falsey if your code parses flags with absl;

  3. using this statement near the top of your main file:

jax.config.disable_omnistaging()

How do I fix bugs exposed by omnistaging?#

By far the most common issue with omnistaging is using jax.numpy to compute shape values or other trace-time constants. See the code block below for a quick example, and for full details along with other issues see the section What issues can arise when omnistaging is switched on?.

Instead of this:

@jit
def f(x):
  input_size = jnp.prod(x.shape)
  if input_size > 100:
    ...

do this:

import numpy as np

@jit
def f(x):
  input_size = np.prod(x.shape)
  if input_size > 100:
    ...

Instead of thinking of jax.numpy as a drop-in replacement for numpy, it’s now better to think of using jax.numpy operations only when you want to perform a computation on an accelerator (like your GPU).

What is “omnistaging” and why is it useful?#

Omnistaging is the name for a JAX core upgrade aimed at staging out more computation from op-by-op Python to XLA, and avoiding any “trace-time constant folding” in jit, pmap, and control flow primitives. As a result, omnistaging improves JAX’s memory performance (sometimes dramatically) both by reducing fragmentation during tracing and by producing fewer large compile-time constants for XLA. It can also improve tracing performance by eliminating op-by-op execution at tracing time. Further, omnistaging simplifies JAX core internals, fixing many outstanding bugs and setting the stage for important upcoming features.

The name “omnistaging” means staging out everything possible.

Toy example#

JAX transformations like jit and pmap stage out computations to XLA. That is, we apply them to functions comprising multiple primitive operations so that rather being executed one at a time from Python the operations are all part of one end-to-end optimized XLA computation.

But exactly which operations get staged out? Until omnistaging, JAX staged out computation based on data dependence only. Here’s an example function, followed by the XLA HLO program it stages out before the omnistaging change:

from jax import jit
import jax.numpy as jnp

@jit
def f(x):
  y = jnp.add(1, 1)
  return x * y

f(3)
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}

Notice that the add operation is not staged out. Instead, we only see a multiply.

Here’s the HLO generated from this function after the omnistaging change:

ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}

Slightly less toy example#

Here’s a less toy example which can arise in practice when we want to create boolean masks:

import jax.numpy as jnp
from jax import lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)

Before omnistaging:

ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

The select operation is staged out, but the operations for constructing the constant mask are not. Rather than being staged out, the operations that construct mask are executed op-by-op at Python tracing time, and XLA only sees a compile time constant constant.1 representing the value of mask. That’s unfortunate, because if we had staged out the operations for constructing mask, XLA could have fused them into the select and avoided materializing the result at all. As a result we end up wasting memory with a potentially-large constant, wasting time dispatching multiple un-fused op-by-op XLA computations, and potentially even fragmenting memory.

(The broadcast that corresponds to the construction of the zeros array for jnp.zeros_like(x) is staged out because JAX is lazy about very simple expressions from google/jax#1668. After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)

The reason the creation of mask is not staged out is that, before omnistaging, jit operates based on data dependence. That is, jit stages out only those operations in a function that have a data dependence on an argument. Control flow primitives and pmap behave similarly. In the case of select_tril, the operations to construct the constant mask do not have a data dependence on the argument x, so they are not staged out; only the lax.select call has a data dependence.

With omnistaging all jax.numpy calls in the dynamic context of a jit-transformed function are staged out to XLA. That is, after omnistaging the computation XLA sees for select_tril is

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}

What issues can arise when omnistaging is switched on?#

As a consequence of staging out all jax.numpy operations from Python to XLA when in the dynamic context of a jit or pmap, some code that worked previously can start raising loud errors. As explained below, these behaviors were already buggy before omnistaging, but omnistaging makes them into hard errors.

Using jax.numpy for shape computations#

Example#

from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))

Error message#

[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

Explanation#

With omnistaging, we can’t use jax.numpy for shape computations as in the use of jnp.prod above because in the dynamic context of a jit function those operations will be staged out of Python as values to be computed at execution time, yet we need them to be compile-time (and hence trace-time) constants.

Before omnistaging, this code wouldn’t have raised an error, but it was a common performance bug: the jnp.prod computation would have been executed on the device at tracing time, meaning extra compilation, transfers, synchronization, allocations, and potentially memory fragmentation.

Solution#

The solution is simply to use the original numpy for shape calculations like these. Not only do we avoid the error, but also we keep the computations on the host (and with lower overheads).

This issue was common enough in code that we tried to make the error message especially good. In addition to the stack trace showing where an abstract tracer value caused a problem (the jnp.reshape line in the full stack trace, on omni.py:10), we also explain why this value became a tracer in the first place by pointing to the upstream primitive operation that caused it to become an abstract tracer (the reduce_prod from jnp.prod on omni.py:9) and to which jit-decorated function the tracer belongs (ex on omni.py:7).

Side-effects#

Example#

from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
  global key
  key, subkey = random.split(key)
  return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!

That last call has repeated randomness but no hard error, because we aren’t re-executing the Python. But if we look at key, we see an escaped tracer when omnistaging is on:

print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>

Before omnistaging, the random.split call would not be staged out and so we wouldn’t get an escaped tracer. The code would still be buggy in that the jitted function wouldn’t be reproducing the semantics of the original function (because of the repeated use of the same PRNG key), ultimately due to the side effect.

With omnistaging on, if we touch key again, we’ll get an escaped tracer error:

random.normal(key, ())

Error message#

[... full stack trace …]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).

Explanation#

The second largest category of omnistaging issues we found had to do with side-effecting code. This code already voided the JAX warranty by transforming effectful functions, but due to pre-omnistaging “trace-time constant folding” behavior, some side effecting functions could nevertheless behave correctly. Omnistaging catches more of these errors.

Solution#

The solution is to identify JAX-transformed functions that rely on side effects, and to rewrite them not to be effectful.

Small numerical differences based on XLA optimizations#

Because with omnistaging more computations are being staged out to XLA, rather than some being executed at trace time, that can have the effect of reordering floating point operations. As a result, we’ve seen numerical behaviors change in a way that causes tests with overly tight tolerances to fail when omnistaging is switched on.

Dependence on JAX internal APIs that changed#

Omnistaging involved some big revisions to JAX’s core code, including removing or changing internal functions. Any code that relies on such internal JAX APIs can break when omnistaging is switched on, either with build errors (from pytype) or runtime errors.

Triggering XLA compile time bugs#

Because omnistaging involves staging out more code to XLA, we’ve seen it trigger pre-existing XLA compile-time bugs on some backends. The best thing to do with these is to report them so we can work with the XLA teams on fixes.