jax.lax.fori_loop

jax.lax.fori_loop(lower, upper, body_fun, init_val)

Loop from lower to upper by reduction to while_loop.

The type signature, in brief, is:

fori_loop :: Int -> Int -> ((int, a) -> a) -> a -> a

The semantics of fori_loop are given by this Python implementation:

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

Unlike that Python version, fori_loop is implemented in terms of a call to while_loop. See the docstring for while_loop for more information.
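A minimal pure-Python sketch of that reduction follows. It assumes the reference semantics of while_loop described in its own docstring (repeat `val = body_fun(val)` while `cond_fun(val)` holds). The names `while_loop` and `fori_loop_via_while` here are hypothetical stand-ins for illustration, not JAX's actual implementation: the loop index is carried alongside the user's value, and the condition stops once the index reaches upper.

```python
def while_loop(cond_fun, body_fun, init_val):
    # Hypothetical pure-Python stand-in for the reference semantics of while_loop.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def fori_loop_via_while(lower, upper, body_fun, init_val):
    # Carry (index, value) so the while-loop body can pass the index to body_fun.
    def cond(carry):
        i, _ = carry
        return i < upper  # upper bound is exclusive

    def body(carry):
        i, val = carry
        return i + 1, body_fun(i, val)

    _, result = while_loop(cond, body, (lower, init_val))
    return result


# Sum of 0 + 1 + 2 + 3 + 4.
print(fori_loop_via_while(0, 5, lambda i, acc: acc + i, 0))  # 10
```

An empty range (lower >= upper) never enters the loop body, so init_val is returned unchanged, matching `range(lower, upper)` in the Python version.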

Also unlike the Python analogue, the loop-carried value val must keep a fixed shape and dtype across all iterations; it is not enough for the values to be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example. In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container with a fixed structure and arrays of fixed shape and dtype at the leaves).
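A short sketch of the fixed-type rule, assuming a recent JAX release. The carry in the first call is a tuple (a nested container) whose leaves keep the same shape and dtype on every iteration. The second call uses a hypothetical `growing_body` that doubles the array's length each step, which JAX rejects when it traces the loop. The exact error message differs between JAX versions, so the sketch only checks that a TypeError is raised.

```python
import jax
import jax.numpy as jnp


def body(i, carry):
    # The carry is a tuple; each leaf keeps its shape and dtype across iterations.
    total, count = carry
    return total + i, count + 1  # float32[3] stays float32[3], int32 stays int32


init = (jnp.zeros(3, dtype=jnp.float32), jnp.int32(0))
total, count = jax.lax.fori_loop(0, 4, body, init)
# total holds [6., 6., 6.] (0 + 1 + 2 + 3 in each slot), count holds 4


def growing_body(i, val):
    # Hypothetical bad body: the output has twice as many elements as the input.
    return jnp.concatenate([val, val])


try:
    jax.lax.fori_loop(0, 2, growing_body, jnp.zeros(3))
    rejected = False
except TypeError:
    rejected = True  # the carry's shape changed between input and output
```

The Python version would happily run growing_body and return an array of length 12; the JAX loop refuses because the type a would not be fixed.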

Parameters:
  • lower – an integer representing the loop index lower bound (inclusive)
  • upper – an integer representing the loop index upper bound (exclusive)
  • body_fun – function of type (int, a) -> a.
  • init_val – initial loop carry value of type a.
Returns:
  Loop value from the final iteration, of type a.
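Because lower and upper may be traced values, fori_loop can also be used when the trip count is only known at run time, for example inside jax.jit. A minimal sketch, assuming a recent JAX release; `factorial` is a hypothetical example function. Note that the carry is an int32 array and the body returns int32 as well, so the type a stays fixed.

```python
import jax
import jax.numpy as jnp


@jax.jit
def factorial(n):
    # Under jit, n is traced, so the upper bound n + 1 is not a Python int here.
    # Iterates i = 1, ..., n (upper bound exclusive), multiplying into the carry.
    return jax.lax.fori_loop(1, n + 1, lambda i, acc: acc * i, jnp.int32(1))


print(factorial(5))  # 120
```

For n = 0 the range from 1 to 1 is empty, so the initial value 1 is returned, which is the expected 0! = 1.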