Source code for jax._src.lax.control_flow.loops

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the loop primitives."""
from functools import partial
import itertools
import operator

from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar

import jax
import weakref
from jax import core
from jax import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import jax._src.pretty_printer as pp
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           tree_map)
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (
    cache,
    extend_name_stack,
    partition_list,
    safe_map,
    safe_zip,
    split_list,
    unzip2,
    weakref_lru_cache,
    )
import numpy as np

from jax._src.lax.control_flow.common import (
    _abstractify,
    _avals_short,
    _check_tree_and_avals,
    _initial_style_jaxpr,
    _make_closed_jaxpr,
    _prune_zeros,
    _typecheck_param,
    allowed_effects,
    )

# jax._src.util's safe_map / safe_zip stand in for the builtins throughout this
# module; note that ``zip`` deliberately shadows the builtin from here on.
# NOTE(review): presumably both verify that their arguments have equal
# lengths -- confirm against jax._src.util.
_map, zip = safe_map, safe_zip

# Generic type variable.
T = TypeVar('T')
# Loose annotation aliases: any array-like value, and a bool or a Boolean
# array respectively.
Array = BooleanNumeric = Any

### Helper functions

def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
  """Promote weakly-typed in_vals to be compatible with out_avals.

  Args:
    in_vals : flattened list of input values.
    in_avals : corresponding list of avals.
    out_avals : list of target output avals.
  Returns:
    in_vals_new : flattened list of modified in_vals with no weak types.
    changed : bool; true if in_vals required modification.
  """
  if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
    # Calling function is responsible for catching this.
    return in_vals, False
  weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
                    if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
  if not weak_mismatches:
    return in_vals, False
  for i in weak_mismatches:
    new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
    in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
  return in_vals, True


### scan

# Type variables used by the ``scan`` signature: the loop-carried state, a
# single slice of the scanned-over input, and a single per-iteration output.
Carry, X, Y = TypeVar('Carry'), TypeVar('X'), TypeVar('Y')

@api_boundary
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: X,
         length: Optional[int] = None,
         reverse: bool = False,
         unroll: int = 1) -> Tuple[Carry, Y]:
  """Scan a function over leading array axes while carrying along state.

  The `Haskell-like type signature`_ in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When ``a`` is an array type or None, and ``b`` is an array type, the semantics
  of ``scan`` are given roughly by this Python implementation::

    def scan(f, init, xs, length=None):
      if xs is None:
        xs = [None] * length
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays. (None is actually an empty pytree.)

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
  across all iterations (and not just be consistent up to NumPy rank/shape
  broadcasting and dtype promotion rules, for example). In other words, the type
  ``c`` in the type signature above represents an array with a fixed shape and
  dtype (or a nested tuple/list/dict container data structure with a fixed
  structure and arrays with fixed shape and dtype at the leaves).

  .. note::
    :py:func:`scan` compiles ``f``, so while it can be combined with
    :py:func:`jit`, it's usually unnecessary.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value. This value must have the same structure as
      the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.
    unroll: optional positive int specifying, in the underlying operation of the
      scan primitive, how many scan iterations to unroll within a single
      iteration of a loop.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.

  Raises:
    TypeError: if ``f`` is not callable, or if ``f`` does not return a pair.
    ValueError: if ``unroll`` is smaller than 1; if a leaf of ``xs`` has no
      leading axis (no ``shape``, or a zero-dimensional one); if the leading
      axis sizes disagree with each other or with ``length``; if there is
      nothing to scan over and ``length`` is not given; or for a zero-length
      scan under ``disable_jit()``.

  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
  """
  if not callable(f):
    raise TypeError("lax.scan: f argument should be a callable.")
  # Validate eagerly: a non-positive unroll would otherwise only surface at
  # lowering time (e.g. as a division by zero when splitting into blocks).
  if unroll < 1:
    raise ValueError(
        f"lax.scan: unroll argument must be a positive int, got {unroll!r}.")
  xs_flat, xs_tree = tree_flatten(xs)

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except (AttributeError, IndexError) as err:
    # AttributeError: a leaf has no ``shape`` at all (e.g. a Python scalar).
    # IndexError: a leaf is zero-dimensional (``shape == ()``).
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
        msg.format(', '.join(str(x) for x in xs_flat
                             if not getattr(x, 'shape', ())))) from err

  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, lengths))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  if config.jax_disable_jit:
    # Eager Python loop; the output type of a zero-length scan cannot be
    # determined without tracing ``f``, hence the error below.
    if length == 0:
      raise ValueError("zero-length scan is not supported in disable_jit() "
                       "mode because the output type is unknown.")
    carry = init
    ys = []
    maybe_reversed = reversed if reverse else lambda x: x
    for i in maybe_reversed(range(length)):
      xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
      carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
      ys.append(y)
    stack = lambda *ys: jax.numpy.stack(ys)
    stacked_y = tree_map(stack, *maybe_reversed(ys))
    return carry, stacked_y

  xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
  x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]

  def _create_jaxpr(init):
    init_flat, init_tree = tree_flatten(init)
    in_flat, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(_map(_abstractify, init_flat))
    jaxpr, consts, out_tree = _initial_style_jaxpr(
        f, in_tree, (*carry_avals, *x_avals), "scan")
    out_tree_children = out_tree.children()
    if len(out_tree_children) != 2:
      msg = "scan body output must be a pair, got {}."
      raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
    carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves]
    return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr,
            consts, out_tree, out_tree_children)

  # The carry input and output avals must match exactly. However, we want to
  # account for the case when init contains weakly-typed values (e.g. Python
  # scalars), with avals that may not match the output despite being
  # compatible by virtue of their weak type. To do this, we compute the jaxpr
  # in two passes: first with the raw inputs, and if necessary, a second time
  # with modified init values.
  init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  new_init_flat, changed = _promote_weak_typed_inputs(
      init_flat, carry_avals, carry_avals_out)
  if changed:
    new_init = tree_unflatten(init_tree, new_init_flat)
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(
        new_init)
  in_flat, jaxpr, consts, out_tree, out_tree_children = rest

  _check_tree_and_avals("scan carry output and input",
                        # Extract the subtree and avals for the first element
                        # of the return tuple
                        out_tree_children[0], carry_avals_out,
                        init_tree, carry_avals)
  disallowed_effects = jaxpr.effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `scan`: {disallowed_effects}')

  out = scan_p.bind(*consts, *in_flat,
                    reverse=reverse, length=length, jaxpr=jaxpr,
                    num_consts=len(consts), num_carry=len(init_flat),
                    linear=(False,) * (len(consts) + len(in_flat)),
                    unroll=unroll)
  return tree_unflatten(out_tree, out)
def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear, f_impl, x_avals, y_avals): consts, init, xs = split_list(args, [num_consts, num_carry]) carry = init ys = [] for i in range(length): i_ = length - i - 1 if reverse else i x = _map(partial(_index_array, i_), x_avals, xs) out = f_impl(*consts, *carry, *x) carry, y = split_list(out, [num_carry]) ys.append(y) ys = list(reversed(ys)) if reverse else ys ys = list(zip(*ys)) ys = _map(_stack, y_avals, ys) return (*carry, *ys) def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear, f_impl, x_avals, y_avals): consts, init, xs = split_list(args, [num_consts, num_carry]) def cond_fun(vals): i, *_ = vals return i < length def body_fun(vals): [i], carry, ys = split_list(vals, [1, num_carry]) i_ = length - i - 1 if reverse else i x = _map(partial(_dynamic_index_array, i_), x_avals, xs) out_flat = f_impl(*consts, *carry, *x) carry_out, y_updates = split_list(out_flat, [num_carry]) ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates) return [i + 1] + carry_out + ys_out ys_init = _map(partial(_empty_array, length), y_avals) if length == 0: return init + ys_init else: init_val = [lax._const(length, 0)] + init + ys_init _, *outs = while_loop(cond_fun, body_fun, init_val) return outs def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry, linear, block_length, f_impl, x_avals, y_avals): consts, init, xs = split_list(args, [num_consts, num_carry]) num_blocks, rem = divmod(length, block_length) assert rem == 0 partition = partial(_partition_leading, num_blocks, block_length) xs_block = _map(partition, x_avals, xs) prepend_aval = partial(_prepend_dim_to_aval, block_length) x_block_avals = _map(prepend_aval, x_avals) y_block_avals = _map(prepend_aval, y_avals) f_impl_block = partial( _scan_impl_unrolled, reverse=reverse, length=block_length, num_consts=num_consts, num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) outs = 
_scan_impl_loop( *consts, *init, *xs_block, reverse=reverse, length=num_blocks, num_consts=num_consts, num_carry=num_carry, linear=linear, f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals) carry, ys_blocks = split_list(outs, [num_carry]) combine = partial(_combine_leading, num_blocks, block_length) ys = _map(combine, y_avals, ys_blocks) return (*carry, *ys) def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) f_impl = core.jaxpr_as_fun(jaxpr) if unroll == 1: return _scan_impl_loop( *args, reverse=reverse, length=length, num_consts=num_consts, num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) consts, init, xs = split_list(args, [num_consts, num_carry]) num_blocks, rem = divmod(length, unroll) length_div = num_blocks * unroll if rem > 0: if reverse: split = partial(_split_leading_dim, rem) xs_rem, xs = unzip2(_map(split, x_avals, xs)) else: split = partial(_split_leading_dim, length_div) xs, xs_rem = unzip2(_map(split, x_avals, xs)) outs = _scan_impl_block_unrolled( *consts, *init, *xs, reverse=reverse, length=length_div, num_consts=num_consts, num_carry=num_carry, linear=linear, block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) carry, ys = split_list(outs, [num_carry]) if rem > 0: outs = _scan_impl_unrolled( *consts, *carry, *xs_rem, reverse=reverse, length=rem, num_consts=num_consts, num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) carry, ys_rem = split_list(outs, [num_carry]) if reverse: ys = _map(_concatenate, y_avals, ys_rem, ys) else: ys = _map(_concatenate, y_avals, ys, ys_rem) return (*carry, *ys) def _stack(aval, vals): vals = [lax.expand_dims(x, (0,)) for x in vals] return lax.concatenate(vals, 0) def _concatenate(aval, x1, x2): return lax.concatenate([x1, x2], 0) def _split_leading_dim(i, aval, x): 
assert x.ndim >= 1 return (slicing.slice_in_dim(x, 0, i), slicing.slice_in_dim(x, i, x.shape[0])) def _dynamic_index_array(i, aval, x): return slicing.dynamic_index_in_dim(x, i, keepdims=False) def _index_array(i, aval, x): return slicing.index_in_dim(x, i, keepdims=False) def _empty_array(sz, aval): return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape)) def _update_array(i, aval, xs, x): return slicing.dynamic_update_index_in_dim(xs, x, i, 0) def _partition_leading(sz0, sz1, aval, x): assert x.ndim >= 1 assert x.shape[0] == sz0 * sz1 return lax.reshape(x, (sz0, sz1, *x.shape[1:])) def _combine_leading(sz0, sz1, aval, x): assert x.ndim >= 2 assert x.shape[0] == sz0 assert x.shape[1] == sz1 return lax.collapse(x, 0, 2) def _prepend_dim_to_aval(sz, aval): return core.unmapped_aval(sz, core.no_axis_name, 0, aval) def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) return carry_avals + ys_avals, jaxpr.effects def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, linear, unroll): num_xs = len(jaxpr.in_avals) - num_carry - num_consts num_ys = len(jaxpr.out_avals) - num_carry nonzeros = [type(t) is not ad_util.Zero for t in tangents] const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry]) # Fixpoint computation of which carry are not ad.zero: either # non-zero from init, or the carry out is non-zero. Each iteration promotes # at least one carry to non-zero. We need at most len(carry) iterations, # but we need one last iteration to prepare the jaxpr based on the final # carry_nz. 
carry_nz = init_nz for _ in range(1 + len(carry_nz)): nonzeros = const_nz + carry_nz + xs_nz jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr( jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys) carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:] if carry_nz_out == carry_nz: break else: carry_nz = _map(operator.or_, carry_nz, carry_nz_out) else: assert False, "Fixpoint not reached" tangents = [ad.instantiate_zeros(t) if nz else t for t, nz in zip(tangents, nonzeros)] consts, init, xs = split_list(primals, [num_consts, num_carry]) all_tangents = split_list(tangents, [num_consts, num_carry]) consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents) jaxpr_jvp_rearranged = ad.rearrange_binders( jaxpr_jvp, [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)], [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)]) consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry]) jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot) + init_linear + [True] * len(init_dot) + xs_linear + [True] * len(xs_dot)) out_flat = scan_p.bind( *(consts + consts_dot + init + init_dot + xs + xs_dot), reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged, num_consts=num_consts + len(consts_dot), num_carry=num_carry + len(init_dot), linear=jaxpr_jvp_linear, unroll=unroll) carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) primals_out = carry + ys tangents_out_iter = iter(carry_dot + ys_dot) tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p) for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): num_ys = len(jaxpr.out_avals) - num_carry unknowns = [not t.pval.is_known() for t in tracers] const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) # Fixpoint computation of which carry elements 
are unknown. Each iteration # promotes at least one carry to unknown. We need at most len(carry) # iterations, but we need one last iteration to prepare the jaxpr based on the # final carry_uk. carry_uk = init_uk for _ in range(1 + len(carry_uk)): unknowns = const_uk + carry_uk + xs_uk jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits( jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys) carry_uk_out, ys_uk = split_list(out_uk, [num_carry]) if carry_uk_out == carry_uk: break else: carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" num_res = len(res_avals) del res_avals, carry_uk_out # Instantiate those inputs which must be treated as unknown from the fixpoint. tracers = [trace.instantiate_const(t) if uk else t for t, uk in zip(tracers, unknowns)] # The residual inputs and outputs of the jaxprs produced haven't yet been # adapted to the scan calling convention; in particular, jaxpr_known has its # residual outputs all at the end, meaning they're extensive outputs (which is # fully general but may be wasteful for residuals which are loop-invariant) # while jaxpr_unknown has its corresponding residual inputs at the front (just # as a convention with partial_eval_jaxpr_nounits), making them constant # inputs. To make them consistent, we move the residual inputs on # jaxpr_unknown to the end, even though we may move some back in the sequel. jaxpr_unknown = pe.move_binders_to_back( jaxpr_unknown, [True] * num_res + [False] * sum(unknowns)) # At this point, all residuals are treated as extensive outputs of jaxpr_known # (and extensive inputs to jaxpr_unknown). But residuals that are loop- # invariant can be hoisted out of the scan, rather than letting them get # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't # want to broadcast the matrix!). So, outside the loop we perform a partial # evaluation with known 'const' inputs (but all other inputs unknown). 
const_pvals = [pe.PartialVal.known(t.pval.get_known()) for t in tracers[:num_consts] if t.pval.is_known()] other_pvals = [pe.PartialVal.unknown(aval) for aval in jaxpr_known.in_avals[len(const_pvals):]] with source_info_util.reset_name_stack(): jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals, instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) # The above trace_to_jaxpr_nounits call computed loop-invariant residuals # (known values in invar_pvals_out) and also computed loop-invariant values # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the # previous consts). We need to collect the computed inteisive residuals, and # move corresponding intensive residual binders in jaxpr_unknown to the front. res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] jaxpr_unknown = pe.move_binders_to_front( jaxpr_unknown, [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals]) del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown. # As another optimization, for any extensive inputs that are just forwarded to # extensive outputs, to avoid a copy (which would be looping over # dynamic-update-slice) we'd rather forward the input tracer/value. That means # pruning some outputs from jaxpr_known here, and updating `out_flat` below. fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr) # Prune fwds_known to include only extensive input to extensive output. 
fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and in_idx is not None and in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk) else None for out_idx, in_idx in enumerate(fwds_known)] # Drop any extensive output we can instead get by forwarding an input. # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts jaxpr_known_.outvars = [x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None] jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ()) del jaxpr_known_ # We use `fwds_known` below when forming the output of scanning jaxpr_known. # Run the known part of the scan (if it has any outputs or effects). known_inputs = (list(jaxpr_known_consts) + [t.pval.get_known() for t in tracers[num_consts:] if t.pval.is_known()]) if not jaxpr_known.out_avals and not jaxpr_known.effects: out_known = [] else: linear_known = [False] * len(known_inputs) # conservative! out_known = scan_p.bind( *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), linear=tuple(linear_known), unroll=unroll) del linear_known # Complete the known output by filling in forwarded values using fwds_known. out_known_iter = iter(out_known) out_known = [next(out_known_iter) if f is None else _maybe_put(known_inputs[f]) for f in fwds_known] assert next(out_known_iter, None) is None del known_inputs, out_known_iter # Split known outputs from residuals. out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)]) assert len(intensive_res) + len(extensive_res) == num_res # Create input tracers for jaxpr_unknown bind. unknown_inputs = [t for t in tracers if not t.pval.is_known()] intensive_res = _map(trace.new_instantiated_const, intensive_res) extensive_res = _map(trace.new_instantiated_const, extensive_res) # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. 
carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval) for y_aval in y_avals] out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in itertools.chain(carry_avals, ys_avals)] del carry_avals, y_avals # Create equation. linear_unknown = tuple([False] * len(intensive_res) + [l for l, uk in zip(linear, unknowns) if uk] + [False] * len(extensive_res)) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) assert len(out_tracers) == len(jaxpr_unknown.out_avals) eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res], out_tracers, scan_p, dict(reverse=reverse, length=length, unroll=unroll, jaxpr=jaxpr_unknown, linear=linear_unknown, num_consts=len(intensive_res) + sum(const_uk), num_carry=sum(carry_uk)), jaxpr_unknown.effects, source) for t in out_tracers: t.recipe = eqn # Merge known and unknown outputs into final result. 
return util.merge_lists(out_uk, out_known, out_tracers) def _maybe_put(x): if isinstance(x, np.ndarray): return jax.device_put(x, jax.devices('cpu')[0]) else: return x def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): # we've only implemented transposing scans with specific lin/nonlin patterns consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) num_ires = len(consts_lin) - sum(consts_lin) num_eres = len(xs_lin) - sum(xs_lin) if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires): raise NotImplementedError if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: raise NotImplementedError if not all(init_lin): pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963 consts, _, xs = split_list(args, [num_consts, num_carry]) ires, _ = split_list(consts, [num_ires]) _, eres = split_list(xs, [sum(xs_lin)]) assert not any(ad.is_undefined_primal(r) for r in ires) assert not any(ad.is_undefined_primal(r) for r in eres) carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) ct_carry, ct_ys = split_list(cts, [num_carry]) ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry) ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys) ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) jaxpr_trans = _transpose_scan_jaxpr( num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes) linear_trans = ([False] * num_ires + [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + [False] * num_eres) outs = scan_p.bind( *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse, length=length, jaxpr=jaxpr_trans, num_consts=num_ires, num_carry=num_consts-num_ires+num_carry, 
linear=tuple(linear_trans), unroll=unroll) ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres # transpose_scan_jaxpr :: ([res1, c, a, res2] -> b) # -> ([res1, CT c, CT b, res2] -> [CT c, CT a]) def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes): num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2 # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals # if an axis isn't reduced res1_avals, c_avals, a_avals, res2_avals = split_list( jaxpr.in_avals, [num_res1, num_c, num_a]) num_b = len(jaxpr.out_avals) b_avals = list(jaxpr.out_avals) @lu.wrap_init def transposed(*res1_cbar_bbar_res2): res1, c_bar, b_bar, res2 = split_list( res1_cbar_bbar_res2, [num_res1, num_c, num_b]) primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] + [ad.UndefinedPrimal(aval) for aval in a_avals] + res2) cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, b_bar) _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar) c_bar = _map(ad.instantiate_zeros_aval, c_avals, _map(ad.add_tangents, c_bar, new_c_bar)) return c_bar + a_bar return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals) def _scan_batching_rule(axis_size, axis_name, main_type, args, dims, reverse, length, jaxpr, num_consts, num_carry, linear, unroll): num_ys = len(jaxpr.out_avals) - num_carry orig_batched = [d is not batching.not_mapped for d in dims] const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry]) # Fixpoint computation of which carry are batched: either # batched from init, or the carry out is batched. Each iteration promotes # at least one carry to batched. We need at most len(carry) iterations, # but we need one last iteration to prepare the jaxpr based on the final # carry_batched. 
carry_batched = init_batched for _ in range(1 + len(carry_batched)): batched = const_batched + carry_batched + xs_batched jaxpr_batched, batched_out = batching.batch_jaxpr( jaxpr, axis_size, batched, instantiate=carry_batched + [False] * num_ys, axis_name=axis_name, main_type=main_type) carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] if carry_batched_out == carry_batched: break else: carry_batched = _map(operator.or_, carry_batched, carry_batched_out) else: assert False, "Fixpoint not reached" consts, init, xs = split_list(args, [num_consts, num_carry]) consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 else x for x, d in zip(consts, consts_bdims)] new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched else batching.moveaxis(x, d, 0) if now_batched else x for x, d, was_batched, now_batched in zip(init, init_bdims, init_batched, carry_batched)] new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1 else x for x, d in zip(xs, xs_bdims)] new_args = new_consts + new_init + new_xs outs = scan_p.bind( *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched, num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll) carry_bdims = [0 if b else batching.not_mapped for b in carry_batched] ys_bdims = [1 if b else batching.not_mapped for b in ys_batched] return outs, carry_bdims + ys_bdims def _masked_scan_jaxpr(jaxpr, num_consts, num_carry): fun = core.jaxpr_as_fun(jaxpr) @lu.wrap_init def masked(*args): [dynamic_length], consts, [i], carry, xs = split_list( args, [1, num_consts, 1, num_carry]) out = fun(*(consts + carry + xs)) new_carry, ys = split_list(out, [num_carry]) new_carry = [lax.select(i < dynamic_length, new_c, c) for new_c, c in zip(new_carry, carry)] return [i + 1] + new_carry + ys aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_)) 
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals) def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)) return scan_p.bind(*args, jaxpr=padded_jaxpr, **params) def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn ) -> Tuple[List[bool], core.JaxprEqn]: jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_xs = len(jaxpr.in_avals) - num_consts - num_carry used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry]) for i in range(1 + num_carry): used_outputs = used_carry_out + used_extensive_out jaxpr_dce, used_inputs = pe.dce_jaxpr( jaxpr.jaxpr, used_outputs, instantiate=[False] * num_consts + used_carry_out + [False] * num_xs) used_consts, used_carry_in, used_extensive_in = \ split_list(used_inputs, [num_consts, num_carry]) if list(used_carry_in) == list(used_carry_out): break else: used_carry_out = _map(operator.or_, used_carry_out, used_carry_in) else: assert False, "Fixpoint not reached" if config.jax_enable_checks: core.check_jaxpr(jaxpr.jaxpr) new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u] new_params = dict(eqn.params, num_consts=sum(used_consts), num_carry=sum(used_carry_in), linear=tuple(new_linear), jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts)) # TODO(mattjj,sharadmv): don't assume effects are never DCE'd? 
new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], eqn.primitive, new_params, eqn.effects, eqn.source_info) assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals ) assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals) return used_inputs, new_eqn # TODO(mattjj): de-duplicate code with _scan_partial_eval def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_ys = len(jaxpr.out_avals) - num_carry # Fixpoint (trivial on 'inst_in', since we might as well make all inputs # available as DCE can subsequently prune any unused ones) const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry]) for _ in range(1 + len(carry_uk)): unks_in = const_uk + carry_uk + xs_uk jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom( jaxpr.jaxpr, in_unknowns=unks_in, in_inst=True, ensure_out_unknowns=carry_uk + [False] * num_ys, ensure_out_inst=True, saveable=saveable) carry_uk_out, ys_uk = split_list(unks_out, [num_carry]) if carry_uk_out == carry_uk: break else: carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts) jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts) # Move all residual binders to the back of jaxpr_staged so they're extensive. # TODO(mattjj): make jaxpr_staged only take instantiated inputs res_avals = jaxpr_staged.in_avals[:num_res] jaxpr_staged = pe.move_binders_to_back( jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals)) # Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to # passing in_inst argument to partial_eval_jaxpr_custom above). 
new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is core.Var and not inst] inst_in = [True] * len(inst_in) # As an optimization, hoist loop-invariant residuals out of the loop rather # than using extensive outputs for them. See _scan_partial_eval for comments. num_const_known = len(const_uk) - sum(const_uk) num_carry_known = len(carry_uk) - sum(carry_uk) num_xs_known = len( xs_uk) - sum( xs_uk) jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \ pe.partial_eval_jaxpr_nounits( jaxpr_known, [False] * num_const_known + [True] * (num_carry_known + num_xs_known), [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res) # jaxpr_known_hoist produces intensive residuals followed by the constants for # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts. _, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res]) jaxpr_staged = pe.move_binders_to_front( jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res)) num_intensive_res = len(loop_dep_res) - sum(loop_dep_res) del loop_dep, num_carry_known, num_xs_known, const_uk # Create residual variables. intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals) ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a) for a in ext_avals_mapped] newvar = core.gensym() intensive_res = _map(newvar, intensive_avals) extensive_res = _map(newvar, ext_avals) # Create known eqn, which is a call_p combining evaluation of # jaxpr_known_hoist and a scan of jaxpr_known_loop. ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(unks_out, eqn.outvars) # jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist # (corresponding to consts_known_lp_avals) followed by known carry and xs. 
linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk] _, linear_known_ = split_list(linear_known_, [num_const_known]) linear_known = [False] * len(consts_known_lp_avals) + linear_known_ params_known = dict(eqn.params, jaxpr=jaxpr_known_loop, num_consts=len(consts_known_lp_avals), num_carry=len(carry_uk)-sum(carry_uk), linear=tuple(linear_known)) @lu.wrap_init def known(*ins_known): consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known]) out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist) intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res]) out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known) return [*intensive_res, *out_loop] call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic( known, [v.aval for v in ins_known]) call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) eqn_known = pe.new_jaxpr_eqn( ins_known, [*intensive_res, *out_binders_known, *extensive_res], core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects, eqn.source_info) # Create the staged eqn. 
_, out_binders_staged = partition_list(inst_out, eqn.outvars) linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) + [False] * len(extensive_res)) params_staged = dict(eqn.params, jaxpr=jaxpr_staged, num_consts=len(intensive_res) + eqn.params['num_consts'], linear=tuple(linear_staged)) eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res], out_binders_staged, eqn.primitive, params_staged, jaxpr_staged.effects, eqn.source_info) new_vars = [*new_inst, *intensive_res, *extensive_res] return eqn_known, eqn_staged, unks_out, inst_out, new_vars def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): avals = [x.aval for x in in_atoms] tc = partial(_typecheck_param, 'scan') tc(reverse, 'reverse', 'bool', type(reverse) is bool) tc(num_consts, 'num_consts', 'non-negative int', type(num_consts) is int and num_consts >= 0) tc(num_carry, 'num_carry', 'non-negative int', type(num_carry) is int and num_carry >= 0) tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr) tc(linear, 'linear', 'tuple of bool', type(linear) is tuple and all(type(x) is bool for x in linear)) tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0) tc(length, 'length', 'non-negative int', type(length) is int and length >= 0) if len(linear) != len(avals): raise core.JaxprTypeError( f'scan param linear has length {len(linear)} for {len(avals)} operands') const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry]) const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list( jaxpr.in_avals, [num_consts, num_carry]) carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry]) x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a) for a in y_avals_mapped] if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)): raise core.JaxprTypeError( f'scan input carry input 
and output types mismatch: ' f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}') if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)): raise core.JaxprTypeError( f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n' f'called with consts of type\n{_avals_short(const_avals)}') if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)): raise core.JaxprTypeError( f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n' f'called with initial carry of type\n{_avals_short(init_avals)}') if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)): raise core.JaxprTypeError( f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n' f'called with sequence of type\n{_avals_short(x_avals)}') return [*init_avals, *y_avals], jaxpr.effects def _scan_pp_rule(eqn, context, settings): printed_params = dict(eqn.params) del printed_params['linear'] if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars): del printed_params['length'] if printed_params['unroll'] == 1: del printed_params['unroll'] if printed_params['num_carry'] == 0: del printed_params['num_carry'] if printed_params['num_consts'] == 0: del printed_params['num_consts'] if not printed_params['reverse']: del printed_params['reverse'] lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes) rhs = [pp.text(eqn.primitive.name), core.pp_kv_pairs(sorted(printed_params.items()), context, settings), pp.text(" ") + core.pp_vars(eqn.invars, context)] annotation = (source_info_util.summarize(eqn.source_info) if settings.source_info else None) return [lhs, pp.text(" = ", annotation=annotation), *rhs] def scan_bind(*args, **params): if config.jax_enable_checks: avals = _map(core.get_aval, args) in_atoms = [core.Var(0, '', a) for a in avals] # dummies _scan_typecheck(True, *in_atoms, **params) core.check_jaxpr(params['jaxpr'].jaxpr) return core.AxisPrimitive.bind(scan_p, *args, **params) scan_p 
= core.AxisPrimitive("scan") scan_p.multiple_results = True scan_p.def_custom_bind(scan_bind) scan_p.def_impl(partial(xla.apply_primitive, scan_p)) scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp ad.reducing_transposes[scan_p] = _scan_transpose pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval xla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, mlir.lower_fun(_scan_impl, multiple_results=True)) batching.axis_primitive_batchers[scan_p] = _scan_batching_rule core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom pe.padding_rules[scan_p] = _scan_padding_rule pe.dce_rules[scan_p] = _scan_dce_rule # TODO(mattjj,frostig): un-comment this pp rule # core.pp_eqn_rules[scan_p] = _scan_pp_rule ### while_loop
@api_boundary
def while_loop(cond_fun: Callable[[T], BooleanNumeric],
               body_fun: Callable[[T], T],
               init_val: T) -> T:
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The `Haskell-like type signature`_ in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  .. note::
    :py:func:`while_loop` compiles ``cond_fun`` and ``body_fun``, so while it
    can be combined with :py:func:`jit`, it's usually unnecessary.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.

  Raises:
    TypeError: if ``cond_fun`` or ``body_fun`` is not callable, or if
      ``cond_fun`` does not return a single boolean scalar.
    NotImplementedError: if ``cond_fun``/``body_fun`` carry effects that are
      not in ``allowed_effects``.

  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
  """
  if not (callable(body_fun) and callable(cond_fun)):
    raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
  if config.jax_disable_jit:
    # With jit disabled, try to run the loop eagerly in Python.
    try:
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val
    except core.ConcretizationTypeError:
      # Can't run this while_loop in Python (e.g. because there's a vmap
      # transformation on it), so we fall back to the primitive version.
      pass

  def _create_jaxpr(init_val):
    # Trace cond_fun and body_fun against the abstract values of init_val and
    # validate that cond_fun produces a single boolean scalar.
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, init_avals, "while_cond")
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, init_avals, "while_loop")
    if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
      msg = "cond_fun must return a boolean scalar, but got pytree {}."
      raise TypeError(msg.format(cond_tree))
    pred_aval = cond_jaxpr.out_avals[0]
    if (not isinstance(pred_aval, ShapedArray)
        or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
      msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
      raise TypeError(msg.format(cond_jaxpr.out_avals))
    return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

  # The body input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
  if changed:
    new_init_val, = tree_unflatten(in_tree, new_init_vals)
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
  cond_jaxpr, cond_consts, body_consts, body_tree = rest

  # in_tree describes the 1-tuple (init_val,); its single child is the carry
  # structure that body_fun's output must reproduce.
  in_tree_children = in_tree.children()
  assert len(in_tree_children) == 1
  _check_tree_and_avals("body_fun output and input",
                        body_tree, body_jaxpr.out_avals,
                        in_tree_children[0], init_avals)
  effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
  disallowed_effects = effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `while`: {disallowed_effects}')
  # Operand order expected by while_p: cond consts, body consts, then carry.
  outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
                      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)
def _while_loop_abstract_eval(*args, cond_jaxpr, body_jaxpr, **kwargs):
  """Abstract evaluation rule for ``while_p``.

  Returns the body's output avals (raised to shaped) together with the joined
  effects of ``cond_jaxpr`` and ``body_jaxpr``.

  Raises:
    NotImplementedError: if any joined effect is not in ``allowed_effects``.
  """
  del args, kwargs
  joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
  disallowed_effects = joined_effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `while`: {disallowed_effects}')
  return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects


def _while_loop_batching_rule(axis_size, axis_name, main_type,
                              args, dims, cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching rule for ``while_p``.

  Determines (by fixpoint) which carry components must be batched, batches the
  cond and body jaxprs accordingly, moves/broadcasts the initial carry into the
  chosen batch positions, and binds a new ``while_p``.

  Returns:
    ``(outs, carry_dims)``: the loop outputs and their batch dimensions.
  """
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])
  cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts])
  cconst_dims, bconst_dims, init_dims = split_list(dims, [cond_nconsts, body_nconsts])

  carry_bat = init_bat
  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations to
  # reach a fixpoint.
  for _ in range(1 + len(carry_bat)):
    _, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat,
        axis_name=axis_name, main_type=main_type)
    if carry_bat == carry_bat_out:
      break
    carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out)
  else:
    assert False, "Fixpoint not reached"

  # Knowing how the carry is batched now, we can determine if the predicate is
  # batched.
  _, (pred_bat,) = batching.batch_jaxpr(
      cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False,
      axis_name=axis_name, main_type=main_type)

  if pred_bat:
    # If the predicate is batched, we have to batch *all* of the carry
    # regardless of if the body needs it.
    carry_bat = [True] * len(carry_bat)
    carry_dims = [0] * len(carry_bat)
    body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        body_jaxpr, axis_size, bconst_dims + carry_dims,
        carry_dims, axis_name=axis_name, main_type=main_type)
    cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        cond_jaxpr, axis_size, cconst_dims + carry_dims, [0],
        axis_name=axis_name, main_type=main_type)
  else:
    # If the predicate is not batched, we can look at the `cond_jaxpr`'s out
    # shape to determine the rank of the predicate. From this rank we pick the
    # dims of the carry to be batched to ensure that the predicate shape is a
    # prefix of the carry in and out shapes. We can then batch the `body_jaxpr`
    # according to these new batch dims.
    cond_rank = len(cond_jaxpr.out_avals[0].shape)
    carry_dims = [cond_rank if b else None for b in carry_bat]
    body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims,
        axis_name=axis_name, main_type=main_type)
    # Now we need to rebatch the `cond_jaxpr` according to the new dims of the
    # carry.
    cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,),
        axis_name=axis_name, main_type=main_type)

  # To prepare the `init` to the `while_p`, we broadcast values if they are
  # unbatched and need to have an out axis. If their current batch axis does not
  # match the one it needs to be for the translation rule to work, we move it
  # into place.
  new_init = []
  for x, old_axis, new_axis in zip(init, init_dims, carry_dims):
    if old_axis is batching.not_mapped and new_axis is not batching.not_mapped:
      new_init.append(batching.broadcast(x, axis_size, new_axis))
    elif old_axis is batching.not_mapped and new_axis is batching.not_mapped:
      new_init.append(x)
    else:
      assert new_axis is not batching.not_mapped
      new_init.append(batching.moveaxis(x, old_axis, new_axis))

  outs = while_p.bind(*(cconsts + bconsts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  return outs, carry_dims


def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                    body_jaxpr):
  """JVP rule for ``while_p``.

  Computes (by fixpoint) which carry tangents are nonzero, then runs a single
  ``while_p`` over primals and the nonzero tangents. The cond jaxpr is given
  extra (unused) binders for the tangent carries so it accepts the augmented
  carry.

  Returns:
    ``(out_carry, out_tangents)``; tangents known to be zero are returned as
    symbolic ``ad_util.Zero`` values.
  """
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])

  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    body_nonzeros = bconst_nz + carry_nz
    body_jvp, nonzeros_out = ad.jvp_jaxpr(
        body_jaxpr, body_nonzeros, instantiate=carry_nz)
    if nonzeros_out == carry_nz:
      break
    carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
  else:
    assert False, "Fixpoint not reached"

  nonzeros = cconst_nz + body_nonzeros
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
  _, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
  bconst_dot = _prune_zeros(bconst_dot)
  init_dot = _prune_zeros(init_dot)

  num_carry = len(primals) - cond_nconsts - body_nconsts

  body_jvp_rearranged = ad.rearrange_binders(
      body_jvp,
      [body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
      [num_carry], [len(init_dot)])

  # Give the cond jaxpr extra (unused) input binders for the tangent carries.
  newvar = core.gensym([cond_jaxpr.jaxpr])
  invars_aug = (
      cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot])
  cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
                                    invars_aug,
                                    cond_jaxpr.jaxpr.outvars,
                                    cond_jaxpr.jaxpr.eqns,
                                    cond_jaxpr.jaxpr.effects)
  cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)

  out = while_p.bind(
      *(cconst + bconst + bconst_dot + init + init_dot),
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_augmented,
      body_nconsts=len(bconst) + len(bconst_dot),
      body_jaxpr=body_jvp_rearranged)

  out_carry, out_carry_dot = split_list(out, [num_carry])
  out_tangents_iter = iter(out_carry_dot)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_carry, nonzeros_out)]
  return out_carry, out_tangents


def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
                        cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
  """Partial evaluation rule for ``while_p`` (JaxprTrace-based)."""
  # As long as some carry (and hence output) are known and the output of
  # `cond_jaxpr` is known, we use a portion of the loop body to compute the
  # known outputs of the `while_loop`. For the unknown outputs we generate a
  # jaxpr to run the whole while, including recomputing the known parts,
  # basically like building in checkpointing/rematerialization. This means that
  # we don't actually save any computation by partial evaluation if there are
  # unknown outputs.
  #
  # What this achieves is twofold: jax.linearize works, and we can give a proper
  # error for reverse differentiation of `while`.

  unknowns = [not t.pval.is_known() for t in tracers]
  params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

  cond_consts_uk, body_consts_uk, carry_init_uk = \
      split_list(unknowns, [cond_nconsts, body_nconsts])

  # Fixpoint computation of unknown carry. Each iteration promotes at least one
  # carry to unknown. We need one last iteration to prepare the jaxpr.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_jaxpr_known, _, carry_out_uk, body_res_avals = pe.partial_eval_jaxpr_nounits(  # type: ignore
        body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
    if carry_out_uk == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
  else:
    assert False, "Fixpoint not reached"

  cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits(  # type: ignore
      cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)

  if cond_uk[0] or not any(unknowns) or all(unknowns):
    # If conditional is unknown, or all inputs are known, or all are unknown,
    # just do the default processing.
    return trace.default_process_primitive(while_p, tracers, params)

  # Run the known part of the while.
  in_consts = [t.pval.get_known() for uk, t in
               zip(cond_consts_uk + body_consts_uk + carry_uk, tracers)
               if not uk]
  cond_nconsts_known = len(cond_consts_uk) - sum(cond_consts_uk)
  body_nconsts_known = len(body_consts_uk) - sum(body_consts_uk)
  num_known_outs = len(carry_uk) - sum(carry_uk)
  # TODO(mattjj): use pe.dce_jaxpr to drop res computations and not just outputs
  body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:num_known_outs]
  out_known = while_p.bind(
      *in_consts, cond_nconsts=cond_nconsts_known, cond_jaxpr=cond_jaxpr_known,
      body_nconsts=body_nconsts_known, body_jaxpr=body_jaxpr_known)
  del body_jaxpr_known

  # Run the whole while_loop to get all the outputs, then merge with known ones
  out_tracers_ = trace.default_process_primitive(while_p, tracers, params)
  out_tracers = [t for t, uk in zip(out_tracers_, carry_uk) if uk]
  return util.merge_lists(carry_uk, out_known, out_tracers)


# TODO(mattjj): de-duplicate code with _while_partial_eval
def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """``partial_eval_jaxpr_custom`` rule for ``while_p``.

  Builds a known ``while_p`` eqn over the known inputs (after a fixpoint that
  may promote known carries to unknown) and stages out the original eqn
  unchanged, which recomputes everything. No residuals are saved.

  Returns:
    ``(eqn_known, eqn_staged, unks_out, inst_out, new_inst)``; ``eqn_known`` is
    ``None`` when the loop condition turns out to be unknown.
  """
  del saveable  # We can't save any residuals anyway (w/o dynamic shapes)!
  cond_jaxpr = eqn.params['cond_jaxpr']
  cond_nconsts = eqn.params['cond_nconsts']
  body_jaxpr = eqn.params['body_jaxpr']
  body_nconsts = eqn.params['body_nconsts']

  cond_consts_uk, body_consts_uk, carry_init_uk = \
      split_list(unks_in, [cond_nconsts, body_nconsts])

  # Fixpoint to compute known part of the body (trivial on 'inst_in', since we
  # make all inputs available as DCE can subsequently prune any unused ones)
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_unks_in = body_consts_uk + carry_uk
    jaxpr_known_, _, carry_uk_out, _, num_res = \
        pe.partial_eval_jaxpr_custom(
            body_jaxpr.jaxpr, in_unknowns=body_unks_in, in_inst=True,
            ensure_out_unknowns=carry_uk, ensure_out_inst=True,
            saveable=ad_checkpoint.nothing_saveable)
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  assert not num_res
  body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts)
  del jaxpr_known_, carry_uk_out, num_res

  # Instantiate all inputs (b/c jaxpr_staged will take all inputs).
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]

  # Compute the known part of cond_fun (basically pruning inputs on known side).
  cond_unks_in = cond_consts_uk + carry_uk
  cond_jaxpr_known_, _, [cond_uk], _, _ = \
      pe.partial_eval_jaxpr_custom(
          cond_jaxpr.jaxpr, cond_unks_in, in_inst=True,
          ensure_out_unknowns=False, ensure_out_inst=True,
          saveable=ad_checkpoint.nothing_saveable)
  # NOTE(mattjj): I think it should be impossible for the condition to be
  # unknown, but asserting that caused a test failure in diffrax. So
  # we handle it: if it is unknown, stage out the whole cond function.
  if cond_uk:
    return None, eqn, [True] * len(carry_uk), [True] * len(carry_uk), new_inst
  cond_jaxpr_known = core.ClosedJaxpr(cond_jaxpr_known_, cond_jaxpr.consts)
  del cond_uk

  # Build the known eqn.
  # The known cond/body jaxprs were partially evaluated against the *fixpoint*
  # carry_uk, which may mark as unknown some carries that were known on input.
  # Partition the eqn's inputs with the same unknowns so the known eqn's
  # operands line up with the known jaxprs' binders (using the original
  # unks_in here would pass extra, unbound carry operands).
  unks_in_fixpoint = [*cond_consts_uk, *body_consts_uk, *carry_uk]
  ins_known, _ = partition_list(unks_in_fixpoint, eqn.invars)
  out_binders_known, _ = partition_list(carry_uk, eqn.outvars)
  params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known,
                      cond_nconsts=len(cond_consts_uk) - sum(cond_consts_uk),
                      body_nconsts=len(body_consts_uk) - sum(body_consts_uk))
  effects_known = core.join_effects(cond_jaxpr_known.effects,
                                    body_jaxpr_known.effects)
  eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p,
                               params_known, effects_known, eqn.source_info)

  # Staged eqn is same as input eqn.
  eqn_staged = eqn

  unks_out = carry_uk
  inst_out = [True] * len(unks_out)
  return eqn_known, eqn_staged, unks_out, inst_out, new_inst


def _while_transpose_error(*_, **kwargs):
  """Transpose rule for ``while_p``: always raises ``ValueError``."""
  raise ValueError("Reverse-mode differentiation does not work for "
                   "lax.while_loop or lax.fori_loop. "
                   "Try using lax.scan instead.")


# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
# this:
# ```
# while cond(x):
#   x = body(x)
# ```
# into something that looks like this:
# ```
# while True:
#   token, pred = cond(token, x)
#   if not pred:
#     break
#   token, x = body(token, x)
# ```
# Unfortunately, with an MHLO while we can't (1) return multiple values
# from a `cond` and (2) can't break a while loop. We thus adopt the
# following rewrite strategy:
# ```
# def new_cond(pred, token, x):
#   return pred
# token, pred = cond(token, x)
# while new_cond(pred, token, x):
#   token, x = body(token, x)
#   token, pred = cond(token, x)
# ```
def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
                    body_nconsts):
  """MLIR lowering rule for ``while_p`` (emits an ``mhlo.WhileOp``).

  If the cond has ordered effects, the loop is rewritten as described in the
  comment above and lowered via ``mlir.lower_fun``. Otherwise, ordered-effect
  tokens of the body are threaded as extra leading loop carries. When the
  predicate is batched (non-scalar), the cond reduces it with logical OR, and
  the body uses ``_pred_bcast_select_mhlo`` to update only the carry elements
  whose predicate is still true.
  """
  pred_aval = cond_jaxpr.out_avals[0]
  batched = bool(pred_aval.shape)
  cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in
                          core.ordered_effects]
  if cond_ordered_effects:
    def cond(args):
      return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]

    def body(args):
      return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))

    def new_cond(pred_args):
      pred, _ = pred_args
      return pred

    def new_body(pred_args):
      _, args = pred_args
      args = body(args)
      pred = cond(args)
      return pred, args

    def fun(*args):
      pred = cond(args)
      _, out = while_loop(new_cond, new_body, (pred, args))
      return out
    return mlir.lower_fun(fun)(ctx, *args)

  loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  body_effects = [eff for eff in body_jaxpr.effects
                  if eff in core.ordered_effects]
  num_tokens = len(body_effects)
  tokens = [ctx.tokens_in.get(eff) for eff in body_effects]
  token_types = [mlir.token_type() for _ in tokens]
  loop_carry_types = [*token_types, *loop_carry_types]
  flat_loop_carry_types = util.flatten(loop_carry_types)
  args = [*tokens, *args]

  flat_args = mlir.flatten_lowering_ir_args(args)
  while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)

  # Loop condition
  cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
  name_stack = extend_name_stack(ctx.module_context.name_stack, 'while')
  with ir.InsertionPoint(cond_block):
    flat_cond_args = [
        cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
    ]
    cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
    # Remove tokens from cond args
    cond_args = cond_args[num_tokens:]
    x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
    cond_ctx = ctx.module_context.replace(
        name_stack=xla.extend_name_stack(name_stack, 'cond'))
    ((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
                                      _map(mlir.ir_constants, cond_jaxpr.consts),
                                      *(x + z))
    if batched:
      # Keep looping while the predicate is true for any batch element.
      pred_ctx = mlir.LoweringRuleContext(
          module_context=ctx.module_context,
          primitive=None,
          avals_in=[pred_aval],
          avals_out=[pred_aval.update(shape=())],
          tokens_in=mlir.TokenSet(),
          tokens_out=None)
      pred, = lax._unary_reduce_lower(
          mhlo.OrOp,
          lambda dtype: np.array(False, dtype),
          pred_ctx,
          pred,
          axes=tuple(range(len(pred_aval.shape))))
    mhlo.ReturnOp([pred])

  # Loop body
  body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
  with ir.InsertionPoint(body_block):
    flat_body_args = [
        body_block.arguments[i] for i in range(len(flat_loop_carry_types))
    ]
    body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
    # Tokens are at the front of the args list to the while loop
    token_args, body_args = util.split_list(body_args, [num_tokens])
    tokens_in = mlir.TokenSet(zip(body_effects, token_args))
    x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
    body_ctx = ctx.module_context.replace(
        name_stack=xla.extend_name_stack(name_stack, 'body'))
    new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
        tokens_in, _map(mlir.ir_constants, body_jaxpr.consts), *(y + z))
    out_tokens = [tokens_out.get(eff) for eff in body_effects]
    if batched:
      # Re-evaluate the per-element predicate and only update carry elements
      # for which it is still true.
      body_pred_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(name_stack,
                                           'body_pred'))
      ((body_pred,),), _ = mlir.jaxpr_subcomp(
          body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
          _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
      new_z = _map(
          partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
          body_jaxpr.out_avals)

    mhlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
                   *util.flatten(y), *util.flatten(new_z)])

  outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
  tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
  if tokens:
    ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
  return z


def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
                     body_nconsts):
  """Custom typecheck rule for ``while_p``.

  Only effects are checked so far; returns the body's output avals and the
  joined effects.
  """
  # TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
  joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
  if joined_effects - allowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `while`: {joined_effects - allowed_effects}')
  return body_jaxpr.out_avals, joined_effects


while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck


def _pred_bcast_select_mhlo(
    pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
    ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
  """Select between ``xs`` and ``ys`` using ``pred`` broadcast along leading dims.

  For token avals, the two tokens are joined with ``AfterAllOp`` instead. For
  arrays, ``pred``'s shape must be a prefix of ``x_y_aval``'s shape; the result
  takes ``xs`` where ``pred`` is true and ``ys`` elsewhere.
  """
  if x_y_aval is core.abstract_token:
    x, = xs
    y, = ys
    return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
  else:
    assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
    x, = xs
    y, = ys
    assert x.type == y.type, (x.type, y.type)
    assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
        pred_aval.shape, x_y_aval)
    bcast_pred = mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
        pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
    return mhlo.SelectOp(bcast_pred, x, y).results


### fori_loop

def _fori_cond_fun(loop_carry):
  """``while_loop`` condition used by ``fori_loop``: continue while ``i < upper``."""
  i, upper, _ = loop_carry
  return lax.lt(i, upper)
@weakref_lru_cache def _fori_body_fun(body_fun): body_fun = weakref.ref(body_fun) def while_body_fun(loop_carry): i, upper, x = loop_carry return lax.add(i, lax._const(i, 1)), upper, body_fun()(i, x) return while_body_fun @weakref_lru_cache def _fori_scan_body_fun(body_fun): body_fun = weakref.ref(body_fun) def scanned_fun(loop_carry, _): i, x = loop_carry return (i + 1, body_fun()(i, x)), None return scanned_fun
@api_boundary
def fori_loop(lower, upper, body_fun, init_val):
  """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.

  The `Haskell-like type signature`_ in brief is

  .. code-block:: haskell

    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

  The semantics of ``fori_loop`` are given by this Python implementation::

    def fori_loop(lower, upper, body_fun, init_val):
      val = init_val
      for i in range(lower, upper):
        val = body_fun(i, val)
      return val

  In particular, as with ``range``, if ``upper <= lower`` then ``body_fun`` is
  never applied and ``init_val`` is returned.

  Unlike that Python version, ``fori_loop`` is implemented in terms of either a
  call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the
  trip count is static (meaning known at tracing time, perhaps because ``lower``
  and ``upper`` are Python integer literals) then the ``fori_loop`` is
  implemented in terms of ``scan`` and reverse-mode autodiff is supported;
  otherwise, a ``while_loop`` is used and reverse-mode autodiff is not
  supported.  See those functions' docstrings for more information.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  .. note::
    :py:func:`fori_loop` compiles ``body_fun``, so while it can be combined with
    :py:func:`jit`, it's usually unnecessary.

  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: function of type ``(int, a) -> a``.
    init_val: initial loop carry value of type ``a``.

  Returns:
    Loop value from the final iteration, of type ``a``.

  Raises:
    TypeError: if ``body_fun`` is not callable, or if ``lower`` and ``upper``
      have different (canonicalized) dtypes.

  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
  """
  if not callable(body_fun):
    raise TypeError("lax.fori_loop: body_fun argument should be callable.")
  # TODO(phawkins): perhaps do more type checking here, better error messages.
  lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
  upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
  if lower_dtype != upper_dtype:
    msg = ("lower and upper arguments to fori_loop must have equal types, "
           "got {} and {}")
    raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))

  # If we can specialize on the trip count, call scan instead of a while_loop
  # to enable efficient reverse-mode differentiation.
  if (isinstance(core.get_aval(lower), ConcreteArray) and
      isinstance(core.get_aval(upper), ConcreteArray)):
    try:
      lower_ = int(lower)
      upper_ = int(upper)
    except TypeError:
      use_scan = False
    else:
      use_scan = True
  else:
    use_scan = False

  if use_scan:
    # Match ``range(lower, upper)``: an empty or inverted range runs zero
    # iterations. Clamping keeps a negative length from ever reaching scan, and
    # lets the disable_jit guard below catch inverted ranges too.
    length = max(upper_ - lower_, 0)
    if config.jax_disable_jit and length == 0:
      # non-jit implementation of scan does not support length=0
      return init_val

    (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
                          None, length=length)
  else:
    _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
                              (lower, upper, init_val))
  return result
### map and miscellaneous rules
@api_boundary
def map(f, xs):
  """Map a function over leading array axes.

  This behaves like Python's builtin ``map``, except that its inputs and
  outputs are stacked arrays. Prefer the ``jax.vmap`` transform unless you need
  to apply ``f`` one element at a time, for example to reduce memory usage or
  to mix heterogeneous computation with other control flow primitives.

  When ``xs`` is an array type, ``map`` is equivalent to this Python
  implementation::

    def map(f, xs):
      return np.stack([f(x) for x in xs])

  Because ``map`` is built on ``scan``, it shares the same advantages over a
  Python loop: ``xs`` may be an arbitrary nested pytree, and the mapped
  computation is compiled only once.

  Args:
    f: a Python function to apply element-wise over the first axis or axes of
      ``xs``.
    xs: values over which to map along the leading axis.

  Returns:
    Mapped values.
  """
  def scan_body(carry, x):
    # ``map`` carries no state between steps; each step just emits ``f(x)``.
    del carry
    return (), f(x)

  _, stacked = scan(scan_body, (), xs)
  return stacked
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
  """Calls RBG in a loop and stacks the results.

  When the key is unbatched, the primitive is bound once and neither output is
  mapped. Otherwise the batch dimension is moved to the front and
  ``rng_bit_generator`` is applied to each key via ``map``. Both stacked
  outputs (new keys and bits) are batched along axis 0.
  """
  (key,) = batched_args
  (key_bdim,) = batch_dims

  def generate(k):
    return lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype,
                                        algorithm=algorithm)

  if key_bdim is batching.not_mapped:
    return generate(key), (None, None)

  keys = batching.moveaxis(key, key_bdim, 0)
  new_keys, bits = map(generate, keys)
  return (new_keys, bits), (0, 0)

batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule  # type: ignore

### associative_scan
@api_boundary
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
  """Performs a scan with an associative binary operation, in parallel.

  For an introduction to associative scans, see [BLE1990]_.

  Args:
    fn: A Python callable implementing an associative binary operation with
      signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
      must satisfy the equation
      ``fn(a, fn(b, c)) == fn(fn(a, b), c)``.

      The inputs and result are (possibly nested Python tree structures of)
      array(s) matching ``elems``. Each array has a dimension in place
      of the ``axis`` dimension. `fn` should be applied elementwise over
      the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
      elementwise function.)

      The result ``r`` has the same shape (and structure) as the two inputs
      ``a`` and ``b``.
    elems: A (possibly nested Python tree structure of) array(s), each with
      an ``axis`` dimension of size ``num_elems``.
    reverse: A boolean stating if the scan should be reversed with respect to
      the ``axis`` dimension.
    axis: an integer identifying the axis over which the scan should occur.
      Negative values count from the last dimension.

  Returns:
    A (possibly nested Python tree structure of) array(s) of the same shape
    and structure as ``elems``, in which the ``k``'th element of ``axis`` is the
    result of recursively applying ``fn`` to combine the first ``k`` elements
    of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``,
    the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.

  Raises:
    TypeError: if ``fn`` is not callable.
    ValueError: if the leaves of ``elems`` disagree on the size of ``axis``.

  Example 1: partial sums of an array of numbers:

  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
  DeviceArray([0, 1, 3, 6], dtype=int32)

  Example 2: partial products of an array of matrices

  >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
  >>> partial_prods = lax.associative_scan(jnp.matmul, mats)
  >>> partial_prods.shape
  (4, 2, 2)

  Example 3: reversed partial sums of an array of numbers

  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
  DeviceArray([6, 6, 5, 3], dtype=int32)

  .. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
    Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
    University.
  """
  if not callable(fn):
    raise TypeError("lax.associative_scan: fn argument should be callable.")
  elems_flat, tree = tree_flatten(elems)

  # Canonicalize `axis` before its first use. `lax.rev` requires nonnegative
  # dimensions, so reversing with the raw value broke `reverse=True` for
  # negative axes even though the forward scan accepted them.
  axis = util.canonicalize_axis(axis, elems_flat[0].ndim)

  if reverse:
    elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]

  def combine(a_flat, b_flat):
    # Lower `fn` to operate on flattened sequences of elems.
    a = tree_unflatten(tree, a_flat)
    b = tree_unflatten(tree, b_flat)
    c = fn(a, b)
    c_flat, _ = tree_flatten(c)
    return c_flat

  # Check that all inputs have a consistent leading dimension `num_elems`.
  num_elems = int(elems_flat[0].shape[axis])
  if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
    raise ValueError('Array inputs to associative_scan must have the same '
                     'first dimension. (saw: {})'
                     .format([elem.shape for elem in elems_flat]))

  # Summary of algorithm:
  #
  # Consider elements of `_scan(elems)` at odd indices. That's the same as first
  # summing successive pairs of elements of `elems` and performing a scan on
  # that half sized tensor. We perform the latter scan by recursion.
  #
  # Now consider the even elements of `_scan(elems)`. These can be computed
  # from the odd elements of `_scan(elems)` by adding each odd element of
  # `_scan(elems)` to the matching even element in the original `elems`.
  #
  # We return the odd and even elements interleaved.
  #
  # For the base case of the recursion we return the first element
  # of `elems` followed by the sum of the first two elements computed as
  # a (small two-down-to-one) reduction step.
  def _scan(elems):
    """Perform scan on `elems`."""

    num_elems = elems[0].shape[axis]

    if num_elems < 2:
      return elems

    # Combine adjacent pairs of elements.
    reduced_elems = combine(
      [slicing.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems],
      [slicing.slice_in_dim(elem, 1, None, stride=2, axis=axis)
       for elem in elems])

    # Recursively compute scan for partially reduced tensors.
    odd_elems = _scan(reduced_elems)

    if num_elems % 2 == 0:
      even_elems = combine(
        [slicing.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems],
        [slicing.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
    else:
      even_elems = combine(
        odd_elems,
        [slicing.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])

    # The first element of a scan is the same as the first element
    # of the original `elems`.
    even_elems = [
      lax.concatenate([slicing.slice_in_dim(elem, 0, 1, axis=axis), result],
                      dimension=axis)
      for (elem, result) in zip(elems, even_elems)]

    return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems))

  scans = _scan(elems_flat)

  if reverse:
    scans = [lax.rev(scanned, [axis]) for scanned in scans]

  return tree_unflatten(tree, scans)
def _interleave(a, b, axis): """Given two Tensors of static shape, interleave them along the first axis.""" assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1 a_pad = [(0, 0, 0)] * a.ndim b_pad = [(0, 0, 0)] * b.ndim a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1) b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1) op = lax.bitwise_or if a.dtype == np.bool_ else lax.add return op(lax.pad(a, lax._const(a, 0), a_pad), lax.pad(b, lax._const(b, 0), b_pad)) ### Cumulative reductions.
def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative sum along `axis`.

  Args:
    operand: the array to accumulate.
    axis: index of the dimension to accumulate along (coerced with ``int``).
    reverse: if true, accumulate from the end of ``axis`` towards its start
      (coerced with ``bool``).

  Returns:
    The result of binding ``cumsum_p`` on ``operand``.
  """
  axis_index, run_backwards = int(axis), bool(reverse)
  return cumsum_p.bind(operand, axis=axis_index, reverse=run_backwards)
def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative product along `axis`.

  Args:
    operand: the array to accumulate.
    axis: index of the dimension to accumulate along (coerced with ``int``).
    reverse: if true, accumulate from the end of ``axis`` towards its start
      (coerced with ``bool``).

  Returns:
    The result of binding ``cumprod_p`` on ``operand``.
  """
  axis_index, run_backwards = int(axis), bool(reverse)
  return cumprod_p.bind(operand, axis=axis_index, reverse=run_backwards)
def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative maximum along `axis`.

  Args:
    operand: the array to accumulate.
    axis: index of the dimension to accumulate along (coerced with ``int``).
    reverse: if true, accumulate from the end of ``axis`` towards its start
      (coerced with ``bool``).

  Returns:
    The result of binding ``cummax_p`` on ``operand``.
  """
  axis_index, run_backwards = int(axis), bool(reverse)
  return cummax_p.bind(operand, axis=axis_index, reverse=run_backwards)
def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
  """Computes a cumulative minimum along `axis`.

  Args:
    operand: the array to accumulate.
    axis: index of the dimension to accumulate along (coerced with ``int``).
    reverse: if true, accumulate from the end of ``axis`` towards its start
      (coerced with ``bool``).

  Returns:
    The result of binding ``cummin_p`` on ``operand``.
  """
  axis_index, run_backwards = int(axis), bool(reverse)
  return cummin_p.bind(operand, axis=axis_index, reverse=run_backwards)
def _cumred_shape_rule(x, *, axis: int, reverse: bool):
  # A cumulative reduction keeps the operand's shape; only `axis` needs
  # validating. Negative axes are rejected rather than wrapped.
  if not 0 <= axis < x.ndim:
    raise ValueError(
        f"axis {axis} is out of bounds for array of shape {x.shape}")
  return x.shape

def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool):
  # cumsum is linear; its transpose is a cumsum along the same axis running in
  # the opposite direction.
  flipped = not reverse
  return [cumsum(t, axis=axis, reverse=flipped)]

def cumred_tpu_impl(window_reduce: Callable, x, *, axis: int, reverse: bool):
  # The TPU compiler handles this reduce_window formulation specially and
  # efficiently; on other backends the same formulation would cost O(n^2).
  n = x.shape[axis]
  if n == 0:
    return x
  ndim = x.ndim
  # A window spanning the whole axis, padded by n - 1 on one side, makes output
  # position k reduce over every element up to (or, when reversed, from) k.
  window_dims = [1] * ndim
  window_dims[axis] = n
  padding = [(0, 0)] * ndim
  padding[axis] = (0, n - 1) if reverse else (n - 1, 0)
  strides = [1] * ndim
  return window_reduce(x, window_dims, strides, padding)

def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int,
                       reverse: bool):
  (operand,), (bdim,) = batched_args, batch_dims
  # The batch dimension sits at `bdim`, so every logical axis at or after it
  # moves one position to the right.
  if axis >= bdim:
    axis += 1
  return prim.bind(operand, axis=axis, reverse=reverse), bdim

def _cumred_dtype_rule(name, operand, *args, **kw):
  if dtypes.issubdtype(operand.dtype, np.number):
    return dtypes.canonicalize_dtype(operand.dtype)
  raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
                  "of number.".format(name, np.dtype(operand.dtype).name))


def _cumulative_reduction_primitive(name, reduce_fn, tpu_reduce_window_fn):
  """Creates and registers a cumulative-reduction primitive called `name`.

  All platforms get an associative-scan lowering built from `reduce_fn`; TPU
  additionally gets a reduce_window lowering built from `tpu_reduce_window_fn`.
  """
  reducer_p = lax.standard_primitive(
      _cumred_shape_rule, partial(_cumred_dtype_rule, name), name)
  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)

  generic_lowering = mlir.lower_fun(partial(associative_scan, reduce_fn),
                                    multiple_results=False)
  mlir.register_lowering(reducer_p, mlir.cache_lowering(generic_lowering))

  tpu_lowering = mlir.lower_fun(partial(cumred_tpu_impl, tpu_reduce_window_fn),
                                multiple_results=False)
  mlir.register_lowering(reducer_p, tpu_lowering, platform='tpu')
  return reducer_p

cumsum_p = _cumulative_reduction_primitive(
    "cumsum", lax.add, windowed_reductions._reduce_window_sum)
ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
cumprod_p = _cumulative_reduction_primitive(
    "cumprod", lax.mul, windowed_reductions._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive(
    "cummax", lax.max, windowed_reductions._reduce_window_max)
cummin_p = _cumulative_reduction_primitive(
    "cummin", lax.min, windowed_reductions._reduce_window_min)


def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  # Differentiation always goes through the parallel-prefix associative scan,
  # whatever the backend, because reduce_window is not arbitrarily
  # differentiable.
  scan_fn = partial(associative_scan, combine_fn, axis=axis, reverse=reverse)
  return api.jvp(scan_fn, primals, tangents)

ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)