# Omnistaging


*mattjj@*
*Sept 25 2020*

This is more of an upgrade guide than a design doc.


## tl;dr

### What's going on?

A change to JAX's tracing infrastructure called "omnistaging" (google/jax#3370) was switched on in jax==0.2.0. This change improves memory performance and trace execution time, and simplifies JAX internals, but it may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it's best to fix the bugs, but omnistaging can also be disabled as a temporary workaround. And we're happy to help you with fixes!

### How do I know if omnistaging broke my code?

The easiest way to tell if omnistaging is responsible is to disable omnistaging and see if the issues go away. See the *What issues can arise when omnistaging is switched on?* section below.

### How can I disable omnistaging for now?

*Note: this applies to JAX versions 0.2.0 through 0.2.11; omnistaging cannot be
disabled in JAX versions 0.2.12 and higher*

It is temporarily possible to disable omnistaging by

1. setting the shell environment variable `JAX_OMNISTAGING` to something falsey;
2. setting the boolean flag `jax_omnistaging` to something falsey if your code parses flags with absl;
3. using this statement near the top of your main file:

```
jax.config.disable_omnistaging()
```
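
For the first option, a minimal sketch (it assumes the variable is set before `jax` is first imported, since JAX reads its configuration from the environment at import time):

```
import os
os.environ["JAX_OMNISTAGING"] = "0"  # any falsey value; must run before jax is imported

import jax
```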

### How do I fix bugs exposed by omnistaging?

By far the most common issue with omnistaging is using `jax.numpy` to compute shape values or other trace-time constants. See the code block below for a quick example, and for full details along with other issues see the section *What issues can arise when omnistaging is switched on?*.

Instead of this:

```
@jit
def f(x):
    input_size = jnp.prod(x.shape)
    if input_size > 100:
        ...
```

do this:

```
import numpy as np

@jit
def f(x):
    input_size = np.prod(x.shape)
    if input_size > 100:
        ...
```

Instead of thinking of `jax.numpy` as a drop-in replacement for `numpy`, it's now better to think of using `jax.numpy` operations only when you want to perform a computation on an accelerator (like your GPU).

## What is "omnistaging" and why is it useful?

Omnistaging is the name for a JAX core upgrade aimed at staging out more computation from op-by-op Python to XLA, and avoiding any "trace-time constant folding" in `jit`, `pmap`, and control flow primitives. As a result, omnistaging improves JAX's memory performance (sometimes dramatically) both by reducing fragmentation during tracing and by producing fewer large compile-time constants for XLA. It can also improve tracing performance by eliminating op-by-op execution at tracing time. Further, omnistaging simplifies JAX core internals, fixing many outstanding bugs and setting the stage for important upcoming features.

The name "omnistaging" means staging out everything possible.

### Toy example

JAX transformations like `jit` and `pmap` stage out computations to XLA. That is, we apply them to functions comprising multiple primitive operations so that, rather than being executed one at a time from Python, the operations are all part of one end-to-end optimized XLA computation.

But exactly which operations get staged out? Until omnistaging, JAX staged out
computation based on data dependence only. Hereâ€™s an example function, followed
by the XLA HLO program it stages out *before* the omnistaging change:

```
from jax import jit
import jax.numpy as jnp

@jit
def f(x):
    y = jnp.add(1, 1)
    return x * y

f(3)
```

```
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}
```

Notice that the `add` operation is not staged out. Instead, we only see a multiply.

Here's the HLO generated from this function *after* the omnistaging change:

```
ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}
```
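
To inspect HLO like this yourself, one option is `jax.xla_computation`, the HLO-inspection API of this era of JAX (a sketch; the exact instruction numbering in the dump may differ from the listings above):

```
from jax import xla_computation
import jax.numpy as jnp

def f(x):
    y = jnp.add(1, 1)
    return x * y

# Trace f with an example argument and print the staged-out HLO text.
print(xla_computation(f)(3).as_hlo_text())
```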

### Slightly less toy example

Here's a less toy example which can arise in practice when we want to create boolean masks:

```
import numpy as np
import jax.numpy as jnp
from jax import jit, lax

@jit
def select_tril(x):
    mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
    return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)
```

*Before* omnistaging:

```
ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}
```

The `select` operation is staged out, but the operations for constructing the constant `mask` are not. Rather than being staged out, the operations that construct `mask` are executed op-by-op at Python tracing time, and XLA only sees a compile-time constant `constant.1` representing the value of `mask`. That's unfortunate, because if we had staged out the operations for constructing `mask`, XLA could have fused them into the `select` and avoided materializing the result at all. As a result we end up wasting memory with a potentially-large constant, wasting time dispatching multiple un-fused op-by-op XLA computations, and potentially even fragmenting memory.

(The `broadcast` that corresponds to the construction of the zeros array for `jnp.zeros_like(x)` is staged out because JAX is lazy about very simple expressions from google/jax#1668. After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)

The reason the creation of `mask` is not staged out is that, before omnistaging, `jit` operates based on data dependence. That is, `jit` stages out only those operations in a function that have a data dependence on an argument. Control flow primitives and `pmap` behave similarly. In the case of `select_tril`, the operations to construct the constant `mask` do not have a data dependence on the argument `x`, so they are not staged out; only the `lax.select` call has a data dependence.

With omnistaging, all `jax.numpy` calls in the dynamic context of a `jit`-transformed function are staged out to XLA. That is, after omnistaging the computation XLA sees for `select_tril` is:

```
ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}
```

## What issues can arise when omnistaging is switched on?

As a consequence of staging out all `jax.numpy` operations from Python to XLA when in the dynamic context of a `jit` or `pmap`, some code that worked previously can start raising loud errors. As explained below, these behaviors were already buggy before omnistaging, but omnistaging makes them into hard errors.

### Using `jax.numpy` for shape computations

#### Example

```
from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
    size = jnp.prod(jnp.array(x.shape))
    return x.reshape((size,))

ex1(jnp.ones((3, 4)))
```

#### Error message

```
[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The error arose in jax.numpy.reshape.
While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:
  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
```

#### Explanation

With omnistaging, we can't use `jax.numpy` for shape computations as in the use of `jnp.prod` above, because in the dynamic context of a `jit`-transformed function those operations will be staged out of Python as values to be computed at execution time, yet we need them to be compile-time (and hence trace-time) constants.

Before omnistaging, this code wouldn't have raised an error, but it was a common performance bug: the `jnp.prod` computation would have been executed on the device at tracing time, meaning extra compilation, transfers, synchronization, allocations, and potentially memory fragmentation.

#### Solution

The solution is simply to use the original `numpy` for shape calculations like these. Not only do we avoid the error, but we also keep the computations on the host (and with lower overheads).
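
For example, here is a corrected version of `ex1` from above (a sketch; the only change is doing the shape arithmetic with host-side NumPy, so `size` is a concrete value at trace time):

```
import numpy as np
from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
    size = np.prod(x.shape)  # computed on the host at trace time
    return x.reshape((size,))

ex1(jnp.ones((3, 4)))
```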

This issue was common enough in code that we tried to make the error message especially good. In addition to the stack trace showing where an abstract tracer value caused a problem (the `jnp.reshape` line), the error message also explains why this value became a tracer in the first place by pointing to the upstream primitive operation that caused it to become an abstract tracer (the `reduce_prod` from `jnp.prod` on ex1.py:6) and to which `jit`-decorated function the tracer belongs (`ex1` at ex1.py:4).

### Side-effects

#### Example

```
from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
    global key
    key, subkey = random.split(key)
    return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!
```

That last call has repeated randomness but no hard error, because we aren't re-executing the Python. But if we look at `key`, we see an escaped tracer *when omnistaging is on*:

```
print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
```

Before omnistaging, the `random.split` call would not be staged out, and so we wouldn't get an escaped tracer. The code would still be buggy in that the jitted function wouldn't reproduce the semantics of the original function (because of the repeated use of the same PRNG key), ultimately due to the side effect.

With omnistaging on, if we touch `key` again, we'll get an escaped tracer error:

```
random.normal(key, ())
```

#### Error message

```
[... full stack trace ...]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).
```

#### Explanation

The second-largest category of omnistaging issues we found had to do with side-effecting code. This code already voided the JAX warranty by transforming effectful functions, but due to the pre-omnistaging "trace-time constant folding" behavior, some side-effecting functions could nevertheless behave correctly. Omnistaging catches more of these errors.

#### Solution

The solution is to identify JAX-transformed functions that rely on side effects, and to rewrite them not to be effectful.
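
For instance, the stateful `init` above can be rewritten in the functional style JAX expects by threading the PRNG key through explicitly (a sketch):

```
from jax import jit, random

@jit
def init(key):
    key, subkey = random.split(key)
    return random.normal(subkey, ()), key

key = random.PRNGKey(0)
val1, key = init(key)
val2, key = init(key)  # a fresh subkey is used each call, so val1 != val2
```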

### Small numerical differences based on XLA optimizations

Because omnistaging stages more computations out to XLA, rather than executing some at trace time, floating point operations can end up reordered. As a result, we've seen numerical behaviors change in a way that causes tests with overly tight tolerances to fail when omnistaging is switched on.
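
This kind of sensitivity is inherent to floating point rather than specific to JAX; here is a minimal illustration (pure NumPy, float32, unrelated to any particular XLA optimization) of how merely reassociating an addition changes the result:

```
import numpy as np

a = np.float32(1.0)
b = np.float32(4e-8)  # smaller than half an ulp of 1.0 in float32

left = (a + b) + b    # each small addend rounds away: exactly 1.0
right = a + (b + b)   # the addends combine first and survive rounding
print(left == right)  # False
```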

### Dependence on JAX internal APIs that changed

Omnistaging involved some big revisions to JAX's core code, including removing or changing internal functions. Any code that relies on such internal JAX APIs can break when omnistaging is switched on, either with build errors (from pytype) or runtime errors.

### Triggering XLA compile-time bugs

Because omnistaging involves staging out more code to XLA, we've seen it trigger pre-existing XLA compile-time bugs on some backends. The best thing to do with these is to report them so we can work with the XLA teams on fixes.