jax.lax.fori_loop

jax.lax.fori_loop(lower, upper, body_fun, init_val)

Loop from lower to upper by reduction to jax.lax.while_loop().

The type signature in brief is

fori_loop :: Int -> Int -> ((int, a) -> a) -> a -> a

The semantics of fori_loop are given by this Python implementation:

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
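As a quick check of those semantics, the sketch below runs one body function through the pure-Python reference loop and through jax.lax.fori_loop and compares the results. The names python_fori_loop and body are hypothetical helpers for illustration only. The initial value passed to fori_loop is an explicit jnp.int32 so that the carry has a fixed dtype.

```python
import jax
import jax.numpy as jnp


def python_fori_loop(lower, upper, body_fun, init_val):
    # Hypothetical helper: the pure-Python reference semantics from above.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def body(i, val):
    # Hypothetical body: add the square of the loop index to the running total.
    return val + i * i


expected = python_fori_loop(0, 5, body, 0)            # 0 + 1 + 4 + 9 + 16
actual = jax.lax.fori_loop(0, 5, body, jnp.int32(0))  # same loop, traced by JAX

print(expected, int(actual))  # 30 30
```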

Unlike that Python version, fori_loop is implemented in terms of either a call to jax.lax.while_loop() or a call to jax.lax.scan(). If the trip count is static (meaning known at tracing time, perhaps because lower and upper are Python integer literals), then fori_loop is implemented in terms of scan and reverse-mode autodiff is supported; otherwise, a while_loop is used and reverse-mode autodiff is not supported. See those functions’ docstrings for more information.
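The sketch below illustrates the static versus dynamic trip count distinction, assuming a recent JAX release. power is a hypothetical helper that multiplies the carry by x, n times. With a Python integer n the trip count is static, so jax.grad works. When n is passed as an ordinary (traced) argument to a jax.jit-compiled function, the loop goes through while_loop: forward-mode jax.jvp still works, while jax.grad is expected to raise an error. The example deliberately does not rely on the exact exception type.

```python
import jax
import jax.numpy as jnp


def power(x, n):
    # Hypothetical helper: multiply the carry by x, n times, giving x ** n.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))


# Static trip count: n is a Python int, so fori_loop lowers to scan and
# reverse-mode autodiff is available.
static_grad = jax.grad(power)(2.0, 3)  # d/dx x**3 at x = 2

# Dynamic trip count: under jit, n becomes a tracer, so fori_loop lowers to
# while_loop. Forward-mode autodiff (jvp) still works ...
power_jit = jax.jit(power)
_, dynamic_tangent = jax.jvp(lambda x: power_jit(x, 3), (2.0,), (1.0,))

# ... but reverse-mode autodiff through the while_loop fails.
try:
    jax.grad(lambda x: power_jit(x, 3))(2.0)
    reverse_mode_failed = False
except Exception:  # exact exception type intentionally not relied on
    reverse_mode_failed = True

print(float(static_grad), float(dynamic_tangent), reverse_mode_failed)
```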

Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
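The sketch below shows a carry that satisfies this rule: a tuple whose leaves keep their shape and dtype on every iteration. It also defines a hypothetical body, grow, that doubles the array length each step; JAX is expected to reject it with a TypeError at trace time, because the output carry type no longer matches the input carry type.

```python
import jax
import jax.numpy as jnp


def body(i, carry):
    # Tuple carry: an int32 scalar and a float32 vector of length 3.
    total, vec = carry
    # Each leaf keeps its shape and dtype from one iteration to the next.
    return total + i, vec * 2.0


init = (jnp.int32(0), jnp.ones(3, dtype=jnp.float32))
total, vec = jax.lax.fori_loop(0, 4, body, init)
print(int(total), vec.tolist())  # 6 [16.0, 16.0, 16.0]


def grow(i, vec):
    # Invalid body (hypothetical): the carry's shape changes from (n,) to (2n,).
    return jnp.concatenate([vec, vec])


try:
    jax.lax.fori_loop(0, 2, grow, jnp.ones(3))
    shape_change_rejected = False
except TypeError:
    shape_change_rejected = True

print(shape_change_rejected)  # True
```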

Parameters
  • lower – an integer representing the loop index lower bound (inclusive)

  • upper – an integer representing the loop index upper bound (exclusive)

  • body_fun – function of type (int, a) -> a.

  • init_val – initial loop carry value of type a.
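A small sketch of how the parameters fit together: with lower = 2 and upper = 5, body_fun sees the indices 2, 3 and 4 (lower inclusive, upper exclusive). The init_val here is a fixed-size buffer that keeps its shape while individual entries are written; the buffer-filling body is an illustrative choice, not something the API requires.

```python
import jax
import jax.numpy as jnp


def body_fun(i, buf):
    # i takes the values 2, 3, 4; write i * i into slot i of the buffer.
    return buf.at[i].set(i * i)


init_val = jnp.zeros(6, dtype=jnp.int32)
buf = jax.lax.fori_loop(2, 5, body_fun, init_val)
print(buf.tolist())  # [0, 0, 4, 9, 16, 0]
```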

Returns

Loop value from the final iteration, of type a.
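Because the return value is the carry after the last iteration, a common pattern is to carry extra state and unpack only the part that is needed afterwards. The hypothetical fib helper below carries a pair of consecutive Fibonacci numbers and returns the first element after n steps.

```python
import jax
import jax.numpy as jnp


def fib(n):
    # Hypothetical helper. Carry (F(k), F(k+1)); each iteration advances k by one.
    def step(_, carry):
        a, b = carry
        return b, a + b

    a, _ = jax.lax.fori_loop(0, n, step, (jnp.int32(0), jnp.int32(1)))
    return a


print(int(fib(10)))  # 55
```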