jax.lax.associative_scan

jax.lax.associative_scan(fn, elems, reverse=False)

Perform a scan with an associative binary operation, in parallel.
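To show concretely what "in parallel" means, here is a minimal sketch of a simpler parallel scan scheme, the Hillis-Steele inclusive scan. It is not the algorithm lax.associative_scan uses internally. The helper name hillis_steele_scan is hypothetical, and the sketch assumes a single array (no nested structures) and an fn that is associative and vectorized over the leading axis. Each round makes one batched call to fn, and ceil(log2(n)) rounds are enough.

```python
import jax.numpy as jnp
from jax import lax


def hillis_steele_scan(fn, x):
    """Inclusive scan in ceil(log2(n)) rounds (hypothetical helper).

    Illustrative only: this is NOT how lax.associative_scan is implemented.
    Assumes `fn` is associative and vectorized over the leading axis,
    and that `x` is a single array rather than a nested structure.
    """
    n = x.shape[0]
    offset = 1
    while offset < n:
        # Position i (i >= offset) combines the partial result ending at
        # i - offset (earlier, first argument) with its own partial result
        # (later, second argument). All positions are updated in one call.
        combined = fn(x[:-offset], x[offset:])
        x = jnp.concatenate([x[:offset], combined], axis=0)
        offset *= 2
    return x


x = jnp.arange(8)
print(hillis_steele_scan(jnp.add, x))     # running sums 0, 1, 3, 6, 10, 15, 21, 28
print(lax.associative_scan(jnp.add, x))   # same values
```

Because the earlier partial result is always passed as the first argument, the sketch also works for operations that are associative but not commutative, such as jnp.matmul.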
Parameters
fn – Python callable implementing an associative binary operation with signature r = fn(a, b). It must satisfy associativity: fn(a, fn(b, c)) == fn(fn(a, b), c). The inputs and result are (possibly nested structures of) array(s) matching elems. Each array has a leading dimension in place of the num_elems dimension of elems, and fn is applied elementwise along that dimension, so it must be vectorized over it (for example, by wrapping a per-element function with jax.vmap). The result r has the same shape (and structure) as the two inputs a and b.
elems – A (possibly nested structure of) array(s), each with leading dimension num_elems.
reverse – A boolean stating whether the scan should be reversed with respect to the leading dimension.
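Since elems may be a nested structure and fn only needs to be associative (not commutative), associative_scan can evaluate a first-order linear recurrence x[t] = a[t] * x[t-1] + b[t] by composing the affine maps x -> a[t] * x + b[t]. The sketch below assumes x[-1] = 0. The helper names are hypothetical, and lax.scan appears only as a sequential reference. Every operation in affine_combine is elementwise, so it is already vectorized over the leading dimension.

```python
import jax.numpy as jnp
from jax import lax


def affine_combine(earlier, later):
    """Compose x -> a1*x + b1 (applied first) with x -> a2*x + b2.

    Each argument is a tuple (a, b) of arrays sharing a leading dimension.
    The composition is associative, as associative_scan requires.
    """
    a1, b1 = earlier
    a2, b2 = later
    return a2 * a1, a2 * b1 + b2


def linear_recurrence_parallel(a, b):
    """x[t] = a[t] * x[t-1] + b[t] with x[-1] = 0 (hypothetical helper)."""
    # elems is a tuple of two arrays; the result has the same structure.
    _, x = lax.associative_scan(affine_combine, (a, b))
    return x


def linear_recurrence_sequential(a, b):
    """Same recurrence with a sequential lax.scan, for comparison."""
    def step(x_prev, ab):
        a_t, b_t = ab
        x_t = a_t * x_prev + b_t
        return x_t, x_t

    _, xs = lax.scan(step, jnp.zeros((), dtype=b.dtype), (a, b))
    return xs


a = jnp.array([0.5, 2.0, 1.0, 3.0])
b = jnp.array([1.0, 1.0, 2.0, 0.5])
print(linear_recurrence_parallel(a, b))    # x == [1.0, 3.0, 5.0, 15.5]
print(linear_recurrence_sequential(a, b))  # same values
```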
Returns
A (possibly nested structure of) array(s) of the same shape and structure as elems, in which the k-th element is the result of recursively applying fn to combine the first k elements of elems. For example, given elems = [a, b, c, ...], the result would be [a, fn(a, b), fn(fn(a, b), c), ...].
 Return type
result
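The return value described above can be checked against a plain sequential loop. The helper sequential_prefix is hypothetical, written only for this illustration, and handles a single array rather than nested structures. jnp.matmul is used because it is associative but not commutative, so the comparison also confirms that the earlier element is passed as fn's first argument.

```python
import jax.numpy as jnp
from jax import lax


def sequential_prefix(fn, elems):
    """out[0] = elems[0]; out[k] = fn(out[k-1], elems[k]) for k >= 1."""
    acc = elems[0]
    outs = [acc]
    for k in range(1, elems.shape[0]):
        acc = fn(acc, elems[k])
        outs.append(acc)
    return jnp.stack(outs)


# Small integer matrices keep every product exact, so the two results can be
# compared with array_equal rather than a floating-point tolerance.
mats = jnp.arange(16).reshape(4, 2, 2) % 3
parallel = lax.associative_scan(jnp.matmul, mats)
sequential = sequential_prefix(jnp.matmul, mats)
print(jnp.array_equal(parallel, sequential))  # True
```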
Example 1: partial sums of an array of numbers:
>>> import jax.numpy as jnp
>>> from jax import lax, random
>>> print(lax.associative_scan(jnp.add, jnp.arange(0, 4)))
[0 1 3 6]
Example 2: partial products of an array of matrices:
>>> mats = random.uniform(random.PRNGKey(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)
Example 3: reversed partial sums of an array of numbers:
>>> print(lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True))
[6 6 5 3]
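For jnp.add, the forward scan matches jnp.cumsum, and reverse=True matches a cumulative sum over the flipped array, flipped back. The check below is a small sketch of that equivalence on an arbitrary input. Any other associative elementwise operation, such as jnp.maximum for a running maximum, can be passed as fn in the same way.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])

prefix_sums = lax.associative_scan(jnp.add, x)
suffix_sums = lax.associative_scan(jnp.add, x, reverse=True)
running_max = lax.associative_scan(jnp.maximum, x)

print(jnp.array_equal(prefix_sums, jnp.cumsum(x)))              # True
print(jnp.array_equal(suffix_sums, jnp.cumsum(x[::-1])[::-1]))  # True
print(running_max)  # running maximum: 3, 3, 4, 4, 5, 9, 9, 9
```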