jax.lax.while_loop
jax.lax.while_loop(cond_fun, body_fun, init_val)

Call body_fun repeatedly in a loop while cond_fun is True.

The type signature in brief is

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a
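As a concrete instance of this signature, the sketch below takes a to be a tuple (i, total) of int32 scalars and sums the integers 0 through 9. The carry layout and the example itself are illustrative choices, not part of the API; cond_fun maps a to a boolean and body_fun maps a back to a value of the same structure, shapes, and dtypes.

```python
import jax.numpy as jnp
from jax import lax


# Illustrative example: the loop-carried type `a` is a tuple (i, total)
# of int32 scalars.
def cond_fun(carry):
    i, _ = carry
    return i < 10  # a -> Bool


def body_fun(carry):
    i, total = carry
    # a -> a: same tuple structure, shapes, and dtypes as the input carry.
    return i + 1, total + i


init_val = (jnp.int32(0), jnp.int32(0))
i, total = lax.while_loop(cond_fun, body_fun, init_val)
print(int(i), int(total))  # 10 45
```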
The semantics of while_loop are given by this Python implementation:

    def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
            val = body_fun(val)
        return val
Unlike that Python version, while_loop is a JAX primitive and is lowered to a single XLA While HLO. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an @jit function are unrolled at trace time: every iteration adds its own operations to the traced program, leading to large XLA computations.

Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Another difference from using Python-native loop constructs is that while_loop is not reverse-mode differentiable. Reverse mode would need to store the values from every iteration, but the number of iterations is not known until run time, and XLA computations require static bounds on memory requirements. Forward-mode differentiation (for example jax.jvp) is supported.

- Parameters
  - cond_fun: function of type a -> Bool.
  - body_fun: function of type a -> a.
  - init_val: value of type a, the initial loop-carried value.
- Return type
  T (the type of init_val)
- Returns
  The loop-carried value after the final iteration of body_fun (init_val itself if cond_fun(init_val) is False), of type a.
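To illustrate the jit and fixed-shape points, here is a sketch assuming a recent JAX release. The function collatz_steps is hypothetical: it counts Collatz steps, so the number of iterations depends on the runtime value of n. A Python while loop over that traced value would fail during tracing under jax.jit, whereas while_loop compiles once for int32 scalar inputs and runs a different trip count per call. The second part swaps in a body that turns the int32 carry into a float; the carry-type check rejects it, which current JAX versions report as a TypeError (treat the exact exception type as an assumption).

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def collatz_steps(n):
    """Count Collatz steps from n down to 1 (hypothetical example)."""

    def cond_fun(carry):
        x, _ = carry
        return x != 1  # keep looping until x reaches 1

    def body_fun(carry):
        x, steps = carry
        # Both branches stay int32, so the carry type never changes.
        x = jnp.where(x % 2 == 0, x // 2, 3 * x + 1)
        return x, steps + 1

    _, steps = lax.while_loop(cond_fun, body_fun, (n, jnp.int32(0)))
    return steps


# One compiled function handles inputs that need different trip counts.
print(int(collatz_steps(jnp.int32(6))))   # 8
print(int(collatz_steps(jnp.int32(27))))  # 111


# A body that changes the carry dtype (int32 -> float) is rejected at trace time.
def halve_to_float(x):
    return x * 0.5


try:
    lax.while_loop(lambda x: x > 1, halve_to_float, jnp.int32(8))
    dtype_error = None
except TypeError as err:
    dtype_error = err

print(dtype_error is not None)  # True
```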
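The differentiation restriction applies to reverse mode only. In this sketch the hypothetical function pow8 squares its input three times inside while_loop, so f(x) = x**8 and f'(x) = 8 * x**7. jax.jvp computes the forward-mode derivative, while jax.grad fails; the ValueError caught here reflects current JAX behavior and is an assumption rather than a documented guarantee.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pow8(x):
    """Square x three times: x -> x**2 -> x**4 -> x**8 (hypothetical example)."""

    def cond_fun(carry):
        i, _ = carry
        return i < 3

    def body_fun(carry):
        i, y = carry
        return i + 1, y * y

    _, y = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), x))
    return y


# Forward mode (jvp) works through while_loop.
y, dy = jax.jvp(pow8, (jnp.float32(2.0),), (jnp.float32(1.0),))
print(float(y), float(dy))  # 256.0 1024.0

# Reverse mode (grad) is rejected.
try:
    jax.grad(pow8)(jnp.float32(2.0))
    grad_error = None
except ValueError as err:
    grad_error = err

print(type(grad_error).__name__)
```

When reverse-mode gradients are needed and the trip count is known in advance, lax.scan or lax.fori_loop with static Python-int bounds are the usual alternatives.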