jax.lax.fori_loop
- jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)
Loop from `lower` to `upper` by reduction to `jax.lax.while_loop()`.

The Haskell-like type signature in brief is

```haskell
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
```
The semantics of `fori_loop` are given by this Python implementation:

```python
def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
```
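As a quick check of these semantics, the sketch below (assuming a recent JAX install; `py_fori_loop` and `body` are hypothetical names for illustration) runs the same body through a plain Python loop and through `jax.lax.fori_loop`. It also shows that an empty range, where `range(lower, upper)` yields nothing, returns `init_val` unchanged.

```python
import jax
import jax.numpy as jnp


def py_fori_loop(lower, upper, body_fun, init_val):
    # Reference semantics: a plain Python loop over range(lower, upper).
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def body(i, acc):
    # Hypothetical body: add i squared to the running total.
    return acc + i * i


init = jnp.int32(0)

ref = py_fori_loop(0, 5, body, init)        # 0 + 1 + 4 + 9 + 16
out = jax.lax.fori_loop(0, 5, body, init)   # same result via JAX
print(int(ref), int(out))  # 30 30

# lower == upper: zero iterations, so init_val is returned as-is.
empty = jax.lax.fori_loop(5, 5, body, init)
print(int(empty))  # 0
```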
As the Python version suggests, setting `upper <= lower` will produce no iterations. Negative or custom increments are not supported.

Unlike that Python version, `fori_loop` is implemented in terms of either a call to `jax.lax.while_loop()` or a call to `jax.lax.scan()`. If the trip count is static (meaning known at tracing time, perhaps because `lower` and `upper` are Python integer literals), then the `fori_loop` is implemented in terms of `scan()` and reverse-mode autodiff is supported; otherwise, a `while_loop` is used and reverse-mode autodiff is not supported. See those functions' docstrings for more information.

Also unlike the Python analogue, the loop-carried value `val` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type `a` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
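The sketch below illustrates the two constraints described above, assuming a recent JAX release; `fib_body`, `fib`, `fib_jit` and `repeated_double` are hypothetical helpers. `fib` threads a tuple carry whose structure, shapes and dtypes stay the same on every iteration. `repeated_double` uses Python integer bounds, so the trip count is static, the loop goes through `scan()`, and `jax.grad` works. `fib_jit` passes a traced `n` as the upper bound, so the `while_loop` path is taken: forward evaluation works, but reverse-mode differentiation through it would not.

```python
import jax
import jax.numpy as jnp


def fib_body(i, carry):
    # Carry is a (prev, curr) tuple of int32 scalars; its structure, shapes
    # and dtypes are identical before and after every iteration.
    prev, curr = carry
    return (curr, prev + curr)


def fib(n):
    prev, _ = jax.lax.fori_loop(0, n, fib_body, (jnp.int32(0), jnp.int32(1)))
    return prev


def repeated_double(x):
    # Python int bounds: the trip count is known at trace time, so this
    # lowers to scan and supports reverse-mode autodiff.
    return jax.lax.fori_loop(0, 3, lambda i, v: v * 2.0, x)


# Under jit, n becomes a tracer, so the bound is dynamic and fori_loop
# uses while_loop (forward evaluation only; no reverse-mode autodiff).
fib_jit = jax.jit(fib)

print(int(fib(10)))                            # 55
print(int(fib_jit(10)))                        # 55
print(float(jax.grad(repeated_double)(1.0)))   # 8.0, since d(8x)/dx = 8
```

Preallocating fixed-size state in the carry (rather than growing a container) is the usual way to stay within the fixed-shape requirement.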
`fori_loop()` compiles `body_fun`, so while it can be combined with `jit()`, it's usually unnecessary.

- Parameters:
lower – an integer representing the loop index lower bound (inclusive)
upper – an integer representing the loop index upper bound (exclusive)
body_fun – function of type `(int, a) -> a`.
init_val – initial loop carry value of type `a`.
unroll (int | bool | None) – An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e. `unroll=False`). This argument is only applicable if the loop bounds are statically known.
- Returns:
Loop value from the final iteration, of type `a`.
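A minimal sketch of the `unroll` parameter, assuming a JAX version that accepts both integer and boolean values as documented above; the helper `total` is hypothetical. Because the bounds are Python integers (statically known), every setting below computes the same result; `unroll` only changes how many body iterations are laid out per rolled step of the compiled loop.

```python
import jax
import jax.numpy as jnp


def total(unroll):
    # Static bounds, so `unroll` is applicable: computes 0 + 1 + ... + 9.
    return jax.lax.fori_loop(0, 10, lambda i, acc: acc + i, jnp.int32(0),
                             unroll=unroll)


rolled = total(None)    # default: no unrolling
by_two = total(2)       # two body iterations per rolled step
by_three = total(3)     # 10 is not a multiple of 3; the result still agrees
fully = total(True)     # loop completely unrolled

print(int(rolled), int(by_two), int(by_three), int(fully))  # 45 45 45 45
```

As a rule of thumb, larger unroll factors can reduce per-iteration loop overhead at the cost of longer compile times and a larger compiled program.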