jax.lax.fori_loop

jax.lax.fori_loop(lower, upper, body_fun, init_val)

Loop from lower to upper by reduction to while_loop.

The type signature in brief is:

fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
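Read as Python type hints, the signature says that the carry type a is threaded unchanged through every call to body_fun. The sketch below spells that out with typing.TypeVar. The wrapper name typed_fori_loop and the type variable A are hypothetical, used only for illustration; this is not necessarily how JAX itself annotates fori_loop. Note also that inside the loop the index usually arrives as a traced integer array rather than a Python int, so the int in the hint is an approximation.

```python
from typing import Callable, TypeVar

import jax
import jax.numpy as jnp

# Hypothetical type variable standing in for the carry type `a`.
A = TypeVar("A")


def typed_fori_loop(
    lower: int,
    upper: int,
    body_fun: Callable[[int, A], A],
    init_val: A,
) -> A:
    """Hypothetical wrapper that only adds type hints around jax.lax.fori_loop."""
    return jax.lax.fori_loop(lower, upper, body_fun, init_val)


# The carry is a float32 scalar; body_fun must return a float32 scalar too.
doubled = typed_fori_loop(0, 3, lambda i, x: x * 2.0, jnp.float32(1.0))
print(float(doubled))  # prints 8.0
```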

The semantics of fori_loop are given by this Python implementation:

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
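As a quick check of those semantics, the sketch below runs the pure-Python reference (renamed fori_loop_ref, a hypothetical name, so it does not shadow the JAX function) and jax.lax.fori_loop on the same body, summing i * i for i in [2, 5).

```python
import jax
import jax.numpy as jnp


def fori_loop_ref(lower, upper, body_fun, init_val):
    # Pure-Python reference semantics, as given in the docstring.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def add_square(i, total):
    # Accumulate i * i into the running total.
    return total + i * i


ref = fori_loop_ref(2, 5, add_square, 0)
# The JAX carry is an int32 array so its type stays fixed across iterations.
out = jax.lax.fori_loop(2, 5, add_square, jnp.int32(0))
print(ref, int(out))  # prints 29 29
```

Both compute 2*2 + 3*3 + 4*4 = 29; the upper bound 5 is excluded.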

Unlike that Python version, fori_loop is implemented in terms of a call to while_loop. See the docstring for while_loop for more information.
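The reduction to while_loop can be sketched by carrying the loop index alongside the user's value. This is a simplified illustration that assumes only the public jax.lax.while_loop(cond_fun, body_fun, init_val) interface; the function name fori_loop_via_while is hypothetical, and JAX's actual implementation handles additional details (such as the dtype of the loop index) that are omitted here.

```python
import jax
import jax.numpy as jnp


def fori_loop_via_while(lower, upper, body_fun, init_val):
    """Hypothetical sketch: express fori_loop with jax.lax.while_loop."""

    def cond_fun(carry):
        i, _ = carry
        # Keep looping while the index is below the exclusive upper bound.
        return i < upper

    def while_body(carry):
        i, val = carry
        # Apply the user's body at index i, then advance the index by one.
        return i + 1, body_fun(i, val)

    # Carry the loop index together with the user's value.
    init_carry = (jnp.int32(lower), init_val)
    _, final_val = jax.lax.while_loop(cond_fun, while_body, init_carry)
    return final_val


result = fori_loop_via_while(0, 4, lambda i, x: x + i, jnp.int32(10))
print(int(result))  # prints 16 (10 + 0 + 1 + 2 + 3)
```

Carrying the index inside the while_loop state is what lets body_fun see i on every iteration even though while_loop itself has no notion of a loop counter.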

Parameters:
  • lower – an integer representing the loop index lower bound (inclusive).
  • upper – an integer representing the loop index upper bound (exclusive).
  • body_fun – function of type (int, a) -> a, called as body_fun(i, val) for each loop index i.
  • init_val – initial loop carry value of type a.
Returns:
  Loop value from the final iteration, of type a. If upper <= lower, body_fun is never called and the result equals init_val.
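Because the carry type a may be any pytree whose structure, shapes, and dtypes stay fixed across iterations (the same requirement while_loop imposes), several values can be threaded through the loop at once. The sketch below, using a hypothetical helper fib, carries a pair of integers to compute Fibonacci numbers under jax.jit; this particular body ignores the index i.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fib(n):
    """Hypothetical example: return the n-th Fibonacci number, with fib(0) == 0."""

    def body_fun(i, carry):
        # carry holds (F(i), F(i + 1)); the loop index i itself is not needed.
        a, b = carry
        return b, a + b

    # n is traced under jit, so the trip count is not known at trace time.
    a, _ = jax.lax.fori_loop(0, n, body_fun, (jnp.int32(0), jnp.int32(1)))
    return a


print(int(fib(10)))  # prints 55
```

The tuple carry keeps the same structure and int32 dtypes on every iteration, which is what the fixed type a in the signature requires.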