jax.lax.associative_scan

jax.lax.associative_scan(fn, elems, reverse=False)

Perform a scan with an associative binary operation, in parallel.

Parameters
  • fn – Python callable implementing an associative binary operation with signature r = fn(a, b). It must satisfy associativity: fn(a, fn(b, c)) == fn(fn(a, b), c). The inputs and result are (possibly nested structures of) array(s) matching elems. Each array has a leading dimension in place of num_elems, and fn must act elementwise along that dimension, because associative_scan calls fn on whole batches of elements at once rather than on one element at a time. The result r has the same shape (and structure) as the two inputs a and b.

  • elems – A (possibly nested structure of) array(s), each with leading dimension num_elems.

  • reverse – A boolean stating whether the scan should run in reverse along the leading dimension.
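To illustrate the fn and elems contracts together, here is a minimal sketch with a nested elems: a tuple of two arrays encoding the first-order linear recurrence x_t = a_t * x_{t-1} + b_t (with x_{-1} = 0). The combine function linear_recurrence_combine and the sample values are hypothetical, chosen only for illustration; NumPy is used solely for the sequential reference loop.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def linear_recurrence_combine(left, right):
    # Hypothetical combine: each element is a pair (a, b) standing for the
    # affine map x -> a * x + b. Applying "left, then right" gives
    # x -> a_r * (a_l * x + b_l) + b_r, which is associative.
    # Only elementwise arithmetic is used, so it also works on the batched
    # slices that associative_scan passes in along the leading dimension.
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r


a = jnp.array([0.5, 2.0, 1.0, 0.25])
b = jnp.array([1.0, 1.0, 3.0, 4.0])

# elems is a tuple (a nested structure) of two arrays with num_elems = 4.
_, xs = lax.associative_scan(linear_recurrence_combine, (a, b))

# Sequential reference: x_t = a_t * x_{t-1} + b_t, starting from 0.
x, expected = 0.0, []
for a_t, b_t in zip(np.asarray(a), np.asarray(b)):
    x = a_t * x + b_t
    expected.append(x)
```

The second array of the returned tuple holds x_t for every t; the first holds the accumulated products of a, which this sketch discards.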

Returns

A (possibly nested structure of) array(s) of the same shape and structure as elems, in which the k-th element is the result of recursively applying fn to combine the first k elements of elems. For example, given elems = [a, b, c, ...], the result would be [a, fn(a, b), fn(fn(a, b), c), ...].
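The definition above can be checked against a plain sequential fold. The helper sequential_scan below is a hypothetical reference written with lax.scan, not part of the library, and it assumes elems is a single array rather than a nested structure. Matrix multiplication is associative but not commutative, so agreement also confirms the argument order fn(earlier, later).

```python
import jax.numpy as jnp
from jax import lax, random


def sequential_scan(fn, elems):
    # Hypothetical reference fold: out[0] = elems[0],
    # out[k] = fn(out[k - 1], elems[k]). It runs one step at a time,
    # so here fn only ever sees single elements.
    def step(carry, x):
        new = fn(carry, x)
        return new, new

    _, rest = lax.scan(step, elems[0], elems[1:])
    return jnp.concatenate([elems[:1], rest], axis=0)


mats = random.uniform(random.PRNGKey(0), (4, 2, 2))
parallel = lax.associative_scan(jnp.matmul, mats)
reference = sequential_scan(jnp.matmul, mats)
```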

Example 1: partial sums of an array of numbers

>>> import jax.numpy as jnp
>>> from jax import lax
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
Array([0, 1, 3, 6], dtype=int32)
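Because the scan runs only over the leading dimension, any trailing dimensions are combined elementwise. A minimal sketch with a 2-D input (the values are arbitrary) and jnp.maximum, which is associative, produces an independent running maximum for each column.

```python
import jax.numpy as jnp
from jax import lax

# Four elements along the leading dimension, each a length-3 vector.
elems = jnp.array([[3, 1, 4],
                   [1, 5, 9],
                   [2, 6, 5],
                   [3, 5, 8]])

# The scan runs over axis 0 only; jnp.maximum compares rows elementwise,
# so each column gets its own prefix maximum.
running_max = lax.associative_scan(jnp.maximum, elems)
```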

Example 2: partial products of an array of matrices

>>> from jax import random
>>> mats = random.uniform(random.PRNGKey(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)

Example 3: reversed partial sums of an array of numbers

>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
Array([6, 6, 5, 3], dtype=int32)
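For a commutative operation such as jnp.add, reverse=True should agree with flipping the input, taking ordinary prefix sums, and flipping back. The sketch below checks that with jnp.flip and jnp.cumsum. It deliberately sticks to addition, since the operand order passed to a non-commutative fn in reverse mode is not spelled out on this page.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(0, 4)

# reverse=True accumulates from the last element toward the first.
rev = lax.associative_scan(jnp.add, x, reverse=True)

# Same result for addition: flip, take prefix sums, flip back.
flipped = jnp.flip(jnp.cumsum(jnp.flip(x)))
```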