jax.experimental.loops module

This experimental module provides syntactic sugar for loops and control flow.

The current implementation should convert loops correctly to JAX's internal representation, and most transformations should work (see below), but we have not yet fine-tuned the performance of the resulting XLA compilation!

By default, loops and control-flow in JAX are executed and inlined during tracing. For example, in the following code the for loop is unrolled during JAX tracing:

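import jax.numpy as np
from jax import ops

arr = np.zeros(5)
for i in range(arr.shape[0]):
  # JAX arrays are immutable, so each update returns a new array.
  arr = ops.index_update(arr, i, arr[i] + 2.)
  if i % 2 == 0:
    arr = ops.index_update(arr, i, arr[i] + 1.)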
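To see the unrolling concretely, the sketch below traces the same computation with jax.make_jaxpr and lists the primitives in the result. It assumes a recent JAX release and uses the x.at[i].add(v) update syntax in place of ops.index_update; the function name unrolled is illustrative only.

```python
import jax
import jax.numpy as jnp


def unrolled(arr):
  # A plain Python loop: it runs while JAX traces, so every iteration
  # is emitted separately into the traced computation.
  for i in range(arr.shape[0]):
    arr = arr.at[i].add(2.)
    if i % 2 == 0:  # `i` is a Python int here, so this is an ordinary Python `if`
      arr = arr.at[i].add(1.)
  return arr


result = unrolled(jnp.zeros(5))
jaxpr = jax.make_jaxpr(unrolled)(jnp.zeros(5))
primitives = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
# No loop primitive (scan/while) appears; the body was inlined per iteration.
print(result)
print(primitives)
```

The traced program grows with the trip count, which is why long or deeply nested Python loops become expensive to trace and compile.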

To capture the structured control flow, you can use JAX's higher-order operations, which require you to express the bodies of loops and conditionals as functions, and to write array updates in a functional style that returns an updated array, e.g.:

import jax.numpy as np
from jax import lax, ops

arr = np.zeros(5)
def loop_body(i, acc_arr):
  arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
  # lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
  return lax.cond(i % 2 == 0,
                  arr1,
                  lambda arr1: ops.index_update(arr1, i, arr1[i] + 1),
                  arr1,
                  lambda arr1: arr1)
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
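For comparison, here is a sketch of the same loop written against a recent JAX release, assuming the x.at[i].add(v) update syntax and the newer lax.cond(pred, true_fun, false_fun, operand) calling convention. Printing the jaxpr shows that, unlike the unrolled version, the loop survives as a single structured loop primitive.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loop_body(i, acc_arr):
  # Functional update: returns a new array instead of mutating acc_arr.
  arr1 = acc_arr.at[i].add(2.)
  # `i` is traced here, so the branch must be expressed with lax.cond.
  return lax.cond(i % 2 == 0,
                  lambda a: a.at[i].add(1.),  # true branch
                  lambda a: a,                # false branch
                  arr1)


def run(arr):
  return lax.fori_loop(0, arr.shape[0], loop_body, arr)


result = run(jnp.zeros(5))
jaxpr_text = str(jax.make_jaxpr(run)(jnp.zeros(5)))
print(result)
print(jaxpr_text)
```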

This API quickly becomes unreadable with deeply nested loops. With the utilities in this module you can write loops and conditionals that look closer to plain Python, as long as you keep the loop-carried state in a special loops.Scope object and use for loops over special scope.range iterators:

import jax.numpy as np
from jax import ops
from jax.experimental import loops
with loops.Scope() as s:
  s.arr = np.zeros(5)  # Create the mutable state of the loop as `scope` fields.
  for i in s.range(s.arr.shape[0]):
    s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
    for _ in s.cond_range(i % 2 == 0):  # Conditionals as loops with 0 or 1 iterations
      s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
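Since Scope.range bodies are converted to lax.scan and cond_range bodies to lax.cond (see the method documentation below), the Scope example corresponds roughly to the hand-written scan sketched here, in which the scope fields become the scan carry. This illustrates the idea only and is not the module's actual generated code; the dict-valued carry named scope_fields is a hypothetical stand-in for the scope object, and the .at[...] syntax assumes a recent JAX release.

```python
import jax.numpy as jnp
from jax import lax


def body(scope_fields, i):
  # `scope_fields` plays the role of the scope: a pytree (here a dict)
  # passed from one iteration to the next as the loop-carried state.
  arr = scope_fields["arr"].at[i].add(2.)
  # The cond_range body becomes a lax.cond over the same state.
  arr = lax.cond(i % 2 == 0, lambda a: a.at[i].add(1.), lambda a: a, arr)
  return {"arr": arr}, None  # new carry, no per-iteration output


init = {"arr": jnp.zeros(5)}
final, _ = lax.scan(body, init, jnp.arange(5))
print(final["arr"])
```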

Loops constructed with range must have literal constant bounds. If you need loops with dynamic bounds, you can use the more general while_range iterator; however, the grad transformation is not supported for such loops:

s.idx = start
for _ in s.while_range(lambda: s.idx < end):
  s.idx += 1
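The while_range body is converted to a lax.while_loop, which is what makes dynamic bounds possible. A minimal sketch of the same counting loop written directly with lax.while_loop, assuming a recent JAX release; here start and end are traced values under jax.jit, and the function name count_up is illustrative.

```python
import jax
from jax import lax


@jax.jit
def count_up(start, end):
  # `start` and `end` are traced under jit, so the trip count is not known
  # at trace time; lax.while_loop allows this, while lax.scan needs a static length.
  return lax.while_loop(lambda idx: idx < end,  # condition on the loop state
                        lambda idx: idx + 1,     # body: returns the new state
                        start)                   # initial loop state


print(count_up(2, 7))
```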

Notes

  • Loops and conditionals to be functionalized can appear only inside scopes constructed with loops.Scope and they must use one of the Scope.range iterators. All other loops are unrolled during tracing, as usual in JAX.

  • Only scope data (stored in fields of the scope object) is functionalized. All other state, e.g., in other Python variables, will not be considered as being part of the loop output. All references to the mutable state should be through the scope, e.g., s.arr.

  • The scope fields can be pytrees, and can themselves be mutable data structures.

  • Conceptually, this model is still “functional” in the sense that a loop over a Scope.range behaves as a function whose input and output are the scope data.

  • Scopes should be passed down to callees that need to use loop functionalization, or they may be nested.

  • The programming model is that the loop body over a scope.range is traced only once, using abstract shape values, similar to how JAX traces function bodies.
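The "traced only once" model can be observed with a Python side effect in the body of a lax.scan, the primitive that Scope.range bodies are converted to. This sketch assumes a recent JAX release; the module-level counter trace_count is purely illustrative.

```python
import jax.numpy as jnp
from jax import lax

trace_count = 0


def body(carry, i):
  global trace_count
  # Python side effect: runs when JAX traces the body, not on every iteration.
  trace_count += 1
  return carry + i, None


total, _ = lax.scan(body, jnp.int32(0), jnp.arange(5))
# Five iterations are executed, but the body is traced once, not five times.
print(total, trace_count)
```

Because the body sees abstract values during tracing, Python-level branching on loop data inside such a body is not possible; that is the role of cond_range.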

Restrictions:
  • Tracing of the loop body must not exit prematurely via return, an exception, or break. This is detected and reported as an error when we encounter unnested scopes.

  • The loop index variable should not be used after the loop. Similarly, data computed in the loop body should not be used outside the loop, except for data stored in fields of the scope object.

  • No new mutable state can be created inside a loop to be functionalized. All mutable state must be created outside all loops and conditionals.

  • Once the loop starts, all updates to the loop state must assign new values with the same abstract values (shape and dtype) as the values at loop start.

  • For a while loop, the condition function must not modify the scope state; this is checked and reported as an error. The grad transformation also does not work for while loops; a bounded loop (range) is an alternative that supports grad.
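The requirement that loop state keep the same abstract value comes from the underlying JAX loop primitives, which reject a body whose output shape or dtype differs from its input. A sketch of a violating loop, assuming a recent JAX release; the function name growing_body is hypothetical, and the code catches either TypeError or ValueError rather than relying on a specific error message.

```python
import jax.numpy as jnp
from jax import lax


def growing_body(i, arr):
  # Violates the restriction: output shape (2n,) differs from input shape (n,).
  return jnp.concatenate([arr, arr])


try:
  lax.fori_loop(0, 3, growing_body, jnp.zeros(2))
  rejected = False
except (TypeError, ValueError) as err:
  # JAX refuses to build the loop because the carried state changes type.
  rejected = True
  print("rejected:", type(err).__name__)
```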

Transformations:
  • All transformations are supported, except grad is not supported for Scope.while_range loops.

  • vmap is very useful for such loops because it pushes more work into the inner loops, which should help performance on accelerators.
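The transformation support mirrors the underlying primitives: a bounded loop (lowered to a scan) supports grad and vmap, while reverse-mode grad through a while loop fails. A sketch using lax directly rather than Scope, assuming a recent JAX release; cube_bounded and cube_while are illustrative names for two ways of computing x**3.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cube_bounded(x):
  # Static bounds: fori_loop lowers to a scan, which supports reverse-mode AD.
  return lax.fori_loop(0, 3, lambda i, acc: acc * x, jnp.ones_like(x))


def cube_while(x):
  # Loop state is (counter, accumulator); the bound is checked in the condition.
  def body(state):
    i, acc = state
    return i + 1, acc * x
  return lax.while_loop(lambda state: state[0] < 3, body,
                        (0, jnp.ones_like(x)))[1]


xs = jnp.array([1., 2., 3.])
grads = jax.vmap(jax.grad(cube_bounded))(xs)  # d/dx x**3 = 3 * x**2

try:
  jax.grad(cube_while)(2.)
  while_grad_ok = True
except ValueError:
  # Reverse-mode differentiation of lax.while_loop is not supported.
  while_grad_ok = False

print(grads, while_grad_ok)
```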

For usage examples, see tests/loops_test.py.

class jax.experimental.loops.Scope

Bases: object

A scope context manager to keep the state of loop bodies for functionalization.

Usage:

def f():
  with Scope() as s:
    s.data = 0.
    for i in s.range(5):
      s.data += 1.
    return s.data

cond_range(pred)

Creates a conditional iterator with 0 or 1 iterations, depending on the boolean pred.

The body is converted to a lax.cond. All JAX transformations work.

Usage:

for _ in s.cond_range(s.field < 0.):
  s.field = -s.field
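Because the cond_range body is converted to a lax.cond, the usage above behaves roughly like the direct lax.cond call sketched below, assuming a recent JAX release and its lax.cond(pred, true_fun, false_fun, operand) calling convention. The function name make_nonnegative is illustrative; the grad call shows that this form is differentiable.

```python
import jax
import jax.numpy as jnp
from jax import lax


def make_nonnegative(field):
  # Same effect as the cond_range body: negate `field` only when it is negative.
  return lax.cond(field < 0., lambda v: -v, lambda v: v, field)


neg_case = make_nonnegative(jnp.float32(-2.))
pos_case = make_nonnegative(jnp.float32(3.))
slope = jax.grad(make_nonnegative)(-2.)  # derivative of -x on the taken branch
print(neg_case, pos_case, slope)
```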

range(first, second=None, third=None)

Creates an iterator for bounded iterations to be functionalized.

The body is converted to a lax.scan, for which all JAX transformations work. The first, second, and third arguments must be integer literals.

Usage:

s.range(5)  # start=0, end=5, step=1
s.range(1, 5)  # start=1, end=5, step=1
s.range(1, 5, 2)  # start=1, end=5, step=2

s.out = 1.
for i in s.range(5):
  s.out += 1.
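Since range bodies are converted to lax.scan, the usage above corresponds roughly to a scan whose carry stands in for s.out. The sketch below assumes the loop indices are supplied to the scan as an array built with jnp.arange, including for a non-unit step; the module's actual lowering may differ, and the names add_one and add_index are illustrative.

```python
import jax.numpy as jnp
from jax import lax


def add_one(out, i):
  # Body of `for i in s.range(5): s.out += 1.`; the carry plays the role of s.out.
  return out + 1., None


out, _ = lax.scan(add_one, jnp.float32(1.), jnp.arange(5))  # i = 0, 1, 2, 3, 4


def add_index(total, i):
  # Accumulates the loop index to show which indices a stepped range visits.
  return total + i, None


# range(1, 5, 2) visits i = 1 and i = 3.
idx_sum, _ = lax.scan(add_index, jnp.int32(0), jnp.arange(1, 5, 2))
print(out, idx_sum)
```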

start_subtrace()

Starts a nested trace, returns the Trace object.

while_range(cond_func)

Creates an iterator that continues as long as cond_func returns true.

The body is converted to a lax.while_loop. The grad transformation does not work.

Usage:

for _ in s.while_range(lambda: s.loss > 1.e-5):
  s.loss = loss(...)

Parameters

cond_func – a function (typically a lambda) with no arguments that returns the condition of the while loop.
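The while_range body is converted to a lax.while_loop. One difference worth noting: lax.while_loop passes the loop state explicitly to its condition function, whereas cond_func takes no arguments and reads the scope. In this sketch, assuming a recent JAX release, the elided loss(...) update is replaced by a hypothetical halving step so the loop terminates predictably.

```python
import jax.numpy as jnp
from jax import lax


def halve(state):
  # Hypothetical stand-in for `s.loss = loss(...)`: halve the loss, count steps.
  loss, steps = state
  return loss * 0.5, steps + 1


# The condition receives the (loss, steps) state as an argument.
loss, steps = lax.while_loop(lambda state: state[0] > 1.e-5,
                             halve,
                             (jnp.float32(1.), jnp.int32(0)))
print(loss, steps)
```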