JAX has limited support for Python concurrency.
It is not permitted to manipulate JAX trace values concurrently from multiple
threads. In other words, while it is permissible to call functions that use JAX
tracing (e.g., jit()) from multiple threads, you must not use
threading to manipulate JAX values inside the implementation of the function
f that is passed to jit(). The most likely outcome if you do this is a
mysterious error from JAX.
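
As a minimal sketch of the permitted pattern, the example below calls a jitted function from several Python threads, while the function body itself starts no threads. The names `scaled_sum` and `run_in_threads` are hypothetical and chosen only for illustration; the thread pool comes from the Python standard library, not from JAX.

```python
from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    # Hypothetical example function. Any tracing happens on the thread
    # that calls scaled_sum; no threads are started inside this body.
    return jnp.sum(x * 2.0)


def run_in_threads(n_workers=4):
    # Hypothetical helper: build one input per worker, then let each
    # worker thread call the jitted function as a whole.
    inputs = [jnp.arange(4.0) + i for i in range(n_workers)]
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        results = list(pool.map(scaled_sum, inputs))
    return [float(r) for r in results]


if __name__ == "__main__":
    print(run_in_threads())
```

The pattern to avoid is the reverse: creating threads inside `scaled_sum` and having them operate on `x`, since `x` is a trace value while the function is being traced.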