jax.disable_jit

jax.disable_jit()

Context manager that disables jit() behavior under its dynamic context.

For debugging, it is useful to have a mechanism that disables jit() everywhere in a dynamic context. Note that this disables not only explicit uses of jit by the user but also any implicit JIT compilation performed by the JAX library. This includes the implicit JIT compilation of body and cond functions passed to higher-level primitives like jax.lax.scan() and jax.lax.while_loop(), JIT used in the implementations of jax.numpy functions, and any other case where jit is used within an API’s implementation.
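As a minimal sketch of the implicit case, the example below puts a Python side effect inside the body of jax.lax.while_loop(), whose cond and body functions are normally traced rather than run on concrete values. The names `cond`, `body`, and `seen` are hypothetical. It assumes that, with JIT disabled, while_loop falls back to running as an ordinary Python loop, so int() succeeds on each loop counter; outside the context manager the same int() call would fail on a tracer.

```python
import jax
import jax.numpy as jnp
from jax import lax

seen = []  # hypothetical log of the loop counters observed by the body


def cond(state):
    i, _ = state
    return i < 3


def body(state):
    i, acc = state
    # int() needs a concrete value; on a tracer it would raise an error.
    seen.append(int(i))
    return i + 1, acc + i


with jax.disable_jit():
    # Assumption: with JIT disabled, while_loop runs as a plain Python loop.
    _, total = lax.while_loop(cond, body, (jnp.int32(0), jnp.int32(0)))

print(seen)        # [0, 1, 2]
print(int(total))  # 3
```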

Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, which represents the set of all possible arrays with a given shape and dtype rather than one concrete array with specific values. You can observe these abstract values if you use a benign side-effecting operation, such as print, inside a jitted function:

>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
[5 7 9]
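Because the abstraction keeps shape and dtype but discards values, those two properties are ordinary Python objects while tracing. The sketch below (the function `g` and the list `trace_log` are hypothetical) records the shape seen during each trace. A second call with the same shape and dtype reuses the cached trace, so the Python body, including its side effect, does not run again; a call with a new shape triggers a fresh trace.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical record of what the traced body sees


@jax.jit
def g(x):
    # Shape and dtype belong to the abstract value, so they are concrete
    # Python objects here even though x itself is a tracer.
    trace_log.append(x.shape)
    return x * 2


g(jnp.array([1, 2, 3]))  # first call: traced, appends (3,)
g(jnp.array([4, 5, 6]))  # same shape and dtype: cached, body not re-run
g(jnp.array([7, 8]))     # new shape: traced again, appends (2,)

print(trace_log)  # [(3,), (2,)]
```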

Here y has been abstracted by jit() to a ShapedArray, which represents an array with a fixed shape and dtype but an arbitrary value. The value of y is also traced, which is why print shows a Traced object rather than numbers. If we want to see concrete values while debugging and avoid tracers entirely, we can use the disable_jit() context manager:

>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
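Since disable_jit() makes the function run eagerly on concrete arrays, Python control flow that depends on array values also works inside it, which can help when narrowing down a tracing error. In the sketch below (the function name is hypothetical), the `if` on a value fails under jit() with jax.errors.ConcretizationTypeError (recent JAX versions raise a subclass of it, which the except clause still catches), but runs normally inside disable_jit().

```python
import jax
import jax.numpy as jnp


@jax.jit
def flip_if_negative_sum(x):
    # Branching on an array value: under jit, x.sum() is a tracer and
    # converting it to a Python bool raises an error.
    if x.sum() > 0:
        return x
    return -x


x = jnp.array([1, -5, 2])

try:
    flip_if_negative_sum(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

with jax.disable_jit():
    # Runs eagerly: x.sum() is -2, so the function returns -x.
    result = flip_if_negative_sum(x)

print(jit_failed)       # True
print(result.tolist())  # [-1, 5, -2]
```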