jax.lax.fori_loop

jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)

Loop from lower to upper by reduction to jax.lax.while_loop() or jax.lax.scan().

The Haskell-like type signature in brief is

fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

The semantics of fori_loop are given by this Python implementation:

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
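
For instance, summing the loop indices with jax.lax.fori_loop should agree with the equivalent Python range loop. This is a minimal sketch; body_fun and the bounds 2 and 5 are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp

def body_fun(i, val):
    # Add the current loop index to the carried value.
    return val + i

# Iterates i = 2, 3, 4 (upper bound exclusive), mirroring range(2, 5).
result = jax.lax.fori_loop(2, 5, body_fun, jnp.int32(0))
expected = sum(range(2, 5))  # 2 + 3 + 4 = 9
```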

As the Python version suggests, setting upper <= lower will produce no iterations. Negative or custom increments are not supported.
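
The sketch below, assuming a recent JAX release, checks that empty ranges return init_val unchanged and shows one way to work around the missing step argument: loop over a plain counter and derive the strided index inside the body. The names start, step, and count are hypothetical illustration values, not fori_loop parameters.

```python
import jax
import jax.numpy as jnp

init = jnp.float32(1.0)

def times_ten(i, v):
    # Would multiply the carry by 10 on every iteration that actually runs.
    return v * 10.0

# upper == lower and upper < lower: the body never runs, so init is returned.
same_bounds = jax.lax.fori_loop(5, 5, times_ten, init)
reversed_bounds = jax.lax.fori_loop(5, 2, times_ten, init)

# No step argument: to visit the indices 0, 3, 6, 9, iterate over a
# counter k in [0, count) and compute the index inside the body.
start, step, count = 0, 3, 4  # hypothetical values for illustration

def strided_body(k, total):
    idx = start + k * step  # maps k = 0, 1, 2, 3 to 0, 3, 6, 9
    return total + idx

strided_sum = jax.lax.fori_loop(0, count, strided_body, jnp.int32(0))
```

A negative step works the same way, since idx is an ordinary expression of k.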

Unlike that Python version, fori_loop is implemented in terms of either a call to jax.lax.while_loop() or a call to jax.lax.scan(). If the trip count is static (meaning known at tracing time, perhaps because lower and upper are Python integer literals), fori_loop is implemented in terms of scan() and reverse-mode autodiff is supported; otherwise, while_loop() is used and reverse-mode autodiff is not supported. See those functions’ docstrings for more information.
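
To illustrate the difference, here is a minimal sketch assuming a recent JAX release. repeated_double is a hypothetical helper. With jax.grad, the non-differentiated argument n stays a Python int, so the trip count is static; under jax.jit, n becomes a tracer and the trip count is dynamic. That reverse mode fails with a ValueError in the dynamic case reflects current JAX behaviour and is an assumption about your installed version.

```python
import jax

def repeated_double(x, n):
    # Doubles x exactly n times, so the result is x * 2**n.
    return jax.lax.fori_loop(0, n, lambda i, v: v * 2.0, x)

# Static trip count: jax.grad differentiates only x, so n stays a Python int,
# fori_loop lowers to scan, and reverse-mode autodiff works.
static_grad = jax.grad(repeated_double)(1.5, 3)  # d/dx (x * 8) = 8.0

# Dynamic trip count: under jax.jit, n is traced, so fori_loop lowers to
# while_loop. Forward-mode autodiff (jax.jvp) still works...
dynamic = jax.jit(repeated_double)
_, dynamic_tangent = jax.jvp(lambda x: dynamic(x, 3), (1.5,), (1.0,))

# ...but reverse-mode autodiff (jax.grad) raises an error.
try:
    jax.grad(lambda x: dynamic(x, 3))(1.5)
    reverse_mode_ok = True
except ValueError:
    reverse_mode_ok = False
```

If you need jax.grad through a loop whose length is only known at run time, one option is to bound it by a static maximum and use scan directly.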

Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations; it is not enough for successive values to be compatible under NumPy rank/shape broadcasting or dtype promotion rules. In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container with a fixed structure and arrays of fixed shape and dtype at the leaves).
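
A common pitfall is growing an array inside the loop. The minimal sketch below, assuming a recent JAX release, shows that appending to the carry is rejected at trace time (with a TypeError in current versions), and that preallocating a buffer and writing into it with .at[i].set keeps the carry's type fixed. grow and squares are hypothetical helpers; the carry in squares is a tuple, illustrating a container with fixed structure.

```python
import jax
import jax.numpy as jnp

def grow(i, v):
    # Appending changes the carry's shape from (k,) to (k + 1,): not allowed.
    return jnp.append(v, i)

try:
    jax.lax.fori_loop(0, 3, grow, jnp.zeros(1, dtype=jnp.int32))
    growing_carry_ok = True
except TypeError:
    growing_carry_ok = False

def squares(n):
    # Preallocate: the buffer keeps shape (n,) and dtype float32 throughout,
    # and the running total stays a float32 scalar.
    init = (jnp.zeros(n, dtype=jnp.float32), jnp.float32(0.0))

    def body(i, carry):
        buf, total = carry
        sq = jnp.asarray(i, dtype=jnp.float32) ** 2
        # Write into slot i instead of appending, so the shape is unchanged.
        return buf.at[i].set(sq), total + sq

    return jax.lax.fori_loop(0, n, body, init)

buf, total = squares(4)
```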

Note

fori_loop() compiles body_fun, so although it can be combined with jit(), doing so is usually unnecessary.

Parameters:
  • lower – an integer representing the loop index lower bound (inclusive)

  • upper – an integer representing the loop index upper bound (exclusive)

  • body_fun – function of type (int, a) -> a.

  • init_val – initial loop carry value of type a.

  • unroll (int | bool | None) – An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (unroll=True) or left completely rolled (unroll=False). This argument is only applicable if the loop bounds are statically known.
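
As a minimal sketch, assuming a JAX version whose fori_loop accepts a boolean unroll (as the signature above indicates), the three settings below compute the same sum of the indices 0 through 9; only the shape of the compiled loop differs. add_index is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def add_index(i, acc):
    # Accumulate the loop index into a float32 carry.
    return acc + jnp.asarray(i, dtype=jnp.float32)

init = jnp.float32(0.0)

# Bounds are Python ints, so they are static and unroll applies.
rolled = jax.lax.fori_loop(0, 10, add_index, init)                     # unroll=None
partly_unrolled = jax.lax.fori_loop(0, 10, add_index, init, unroll=4)  # 4 iterations per rolled step
fully_unrolled = jax.lax.fori_loop(0, 10, add_index, init, unroll=True)
```

Unrolling trades larger programs and longer compile times for fewer loop-control steps at run time; the trip count need not be a multiple of an integer unroll factor.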

Returns:

Loop value from the final iteration, of type a.