# JAX Frequently Asked Questions

We are collecting here answers to frequently asked questions. Contributions welcome!

## Creating arrays with jax.numpy.array is slower than with numpy.array

The following code is relatively fast when using NumPy, and slow when using JAX’s NumPy:

```
import numpy as np
from jax import numpy as jnp
np.array([0] * int(1e6))   # fast
jnp.array([0] * int(1e6))  # slow
```

The reason is that numpy.array is implemented in C, while jax.numpy.array is implemented in Python and must iterate over the long list, converting each list element to an array element one at a time.

An alternative is to create the array with plain NumPy first and then convert it to a JAX array:

```
from jax import numpy as jnp
jnp.array(np.array([0] * int(1e6)))
```
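
To see the difference, you can time the three variants. A rough sketch (exact numbers depend on your machine and backend, and the first JAX call also pays one-time dispatch overhead):

```
import timeit
import numpy as np
from jax import numpy as jnp

data = [0] * int(1e6)

print(timeit.timeit(lambda: np.array(data), number=1))             # fast
print(timeit.timeit(lambda: jnp.array(data), number=1))            # slow
print(timeit.timeit(lambda: jnp.array(np.array(data)), number=1))  # fast again
```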

## jit changes the behavior of my function

If a Python function changes behavior after you apply jit, it probably uses global state or has side effects. In the following code, impure_func reads the global y and has a side effect (the print):

```
from jax import jit

y = 0

# @jit  # Different behavior with jit
def impure_func(x):
    print("Inside:", y)
    return x + y

for y in range(3):
    print("Result:", impure_func(y))
```

Without jit the output is:

```
Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4
```

and with jit it is:

```
Inside: 0
Result: 0
Result: 1
Result: 2
```

With jit, the function is traced once by the Python interpreter; that is when the Inside printing happens and the first value of y (here 0) is observed. The function is then compiled and cached, and executed multiple times with different values of x, but always with that same first value of y.
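
One way to restore the expected behavior is to make the function pure, e.g. by passing y as an explicit argument instead of closing over a global. A minimal sketch (pure_func is an illustrative name, not a JAX API):

```
from jax import jit

@jit
def pure_func(x, y):
    # y is now an explicit input, so its runtime value is used on every
    # call instead of the value captured when the function was traced.
    return x + y

for y in range(3):
    print("Result:", pure_func(y, y))  # prints 0, 2, 4, matching the un-jitted output
```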

Additional reading:

- [JAX - The Sharp Bits: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Pure-functions)

## Gradients contain NaN when using `where`

If you define a function using `where` to avoid an undefined value, you may obtain a NaN in reverse-mode differentiation if you are not careful:

```
import jax
from jax import numpy as jnp

def my_log(x):
    return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.)            # ==> 0.  Ok
jax.grad(my_log)(0.)  # ==> nan
```

A short explanation is that during the `grad` computation the adjoint corresponding to the undefined `jnp.log(x)` is a NaN, and that NaN gets accumulated into the adjoint of the `jnp.where`. The correct way to write such functions is to put a `jnp.where` *inside* the partially-defined function, so that the adjoint is always finite:

```
def safe_for_grad_log(x):
    return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.)            # ==> 0.  Ok
jax.grad(safe_for_grad_log)(0.)  # ==> 0.  Ok
```

The inner `jnp.where` may be needed in addition to the original one, e.g.:

```
def my_log_or_y(x, y):
    """Return log(x) if x > 0, else y."""
    return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
```
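
Differentiating at `x = 0.` now gives a finite gradient. A quick check, assuming the imports and definition above (expected outputs shown as comments):

```
print(my_log_or_y(0., -1.))            # ==> -1.0
print(jax.grad(my_log_or_y)(0., -1.))  # ==> 0.0, not nan
```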

Additional reading:

- [Issue: gradients through np.where when one of branches is nan](https://github.com/google/jax/issues/1052#issuecomment-514083352)
- [How to avoid NaN gradients when using `where`](https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf)