jax.lax.associative_scan(fn, elems, reverse=False, axis=0)
Performs a scan with an associative binary operation, in parallel.
For an introduction to associative scans, see [BLE1990].
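"In parallel" can be made concrete with a simplified recursive odd/even scheme in the spirit of [BLE1990]. At each level, adjacent pairs are combined in one batched call to fn. The half-length result is scanned recursively, and the remaining even positions are then filled in with one more batched call. This gives O(log n) levels of batched calls and O(n) total work. The helper parallel_prefix_scan is hypothetical and written only for illustration. It handles a single array scanned along its leading axis (no pytrees, reverse, or axis options), and it is not claimed to be JAX's actual implementation.

```python
import jax
import jax.numpy as jnp


def parallel_prefix_scan(fn, xs):
    """Hypothetical helper: inclusive scan of xs along axis 0 with associative fn.

    fn must act elementwise over the leading axis (e.g. jnp.add or jnp.matmul),
    so each recursion level issues a constant number of batched fn calls.
    """
    n = xs.shape[0]
    if n < 2:
        return xs
    # Combine adjacent pairs (x0, x1), (x2, x3), ... in a single batched call.
    paired = fn(xs[0:-1:2], xs[1::2])
    # odd_scan[i] is the inclusive scan value at original index 2*i + 1.
    odd_scan = parallel_prefix_scan(fn, paired)
    # Index 2*j (j >= 1) combines the prefix ending at 2*j - 1 with xs[2*j].
    if n % 2 == 0:
        even_tail = fn(odd_scan[:-1], xs[2::2])
    else:
        even_tail = fn(odd_scan, xs[2::2])
    even_scan = jnp.concatenate([xs[:1], even_tail])
    # Interleave even- and odd-indexed results back into original order.
    out = jnp.zeros_like(xs)
    out = out.at[0::2].set(even_scan)
    out = out.at[1::2].set(odd_scan)
    return out


xs = jnp.arange(1, 8)
ours = parallel_prefix_scan(jnp.add, xs)
ref = jax.lax.associative_scan(jnp.add, xs)
print(ours)  # expected values: 1, 3, 6, 10, 15, 21, 28
```

Argument order is preserved at every step: the earlier prefix is always passed as the first argument to fn. This is what makes the sketch valid for non-commutative operations such as matrix multiplication.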
Parameters:
fn – A Python callable implementing an associative binary operation with signature r = fn(a, b). Function fn must be associative, i.e., it must satisfy the equation fn(a, fn(b, c)) == fn(fn(a, b), c).
The inputs and result are (possibly nested Python tree structures of) array(s) matching elems. Each array keeps a dimension in place of the axis dimension, and its length may differ from that of elems. fn should therefore be applied elementwise over that dimension (for example, by using jax.vmap() over the elementwise function).
The result r has the same shape (and structure) as the two inputs a and b.
elems – A (possibly nested Python tree structure of) array(s), each with an axis dimension of the same size.
reverse (bool) – A boolean stating whether the scan should be reversed along axis.
axis (int) – An integer identifying the axis over which the scan should occur.
Returns:
A (possibly nested Python tree structure of) array(s) of the same shape and structure as elems, in which the k'th element along axis is the result of recursively applying fn to combine the first k elements of elems along axis. For example, given elems = [a, b, c, ...], the result would be [a, fn(a, b), fn(fn(a, b), c), ...].
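A common use of a tree-structured elems is a first-order linear recurrence h[t] = a[t] * h[t-1] + b[t]. The sketch below assumes h[-1] = 0, and its helper names (compose_affine, linear_recurrence, linear_recurrence_loop) are hypothetical. Each step is treated as an affine map stored as a tuple (a, b). Composing affine maps is associative, so the tuple can be passed directly as elems. compose_affine uses only elementwise operations, so it already works over the extra dimension in place of axis, and no jax.vmap() is needed.

```python
import jax
import jax.numpy as jnp


def compose_affine(left, right):
    """Compose h -> a1*h + b1 (applied first) with h -> a2*h + b2.

    left and right are (a, b) tuples of arrays; every operation is elementwise,
    so batches passed along the scanned axis are handled automatically.
    """
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2


def linear_recurrence(a, b):
    """All h[t] = a[t] * h[t-1] + b[t] with h[-1] = 0, via associative_scan."""
    # Applying the composed map for steps 0..t to h = 0 leaves its b component.
    _, h = jax.lax.associative_scan(compose_affine, (a, b))
    return h


def linear_recurrence_loop(a, b):
    """Sequential reference implementation used to check the scan."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return jnp.stack(out)


a = jnp.array([0.5, 2.0, -1.0, 0.25])
b = jnp.array([1.0, 1.0, 3.0, 4.0])
print(linear_recurrence(a, b))  # expected values: 1, 3, 0, 4
```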
Example 1: partial sums of an array of numbers:
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
Array([0, 1, 3, 6], dtype=int32)
Example 2: partial products of an array of matrices:
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)
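The partial products can also be checked value by value against a sequential left fold, not just by shape. This sketch swaps the uniform floats for small integer matrices (an assumption made for the check). jnp.matmul is associative but not commutative, and with integers every product is exact however the scan groups them, so exact equality can be used. prefix_fold_reference is a hypothetical helper, not part of JAX.

```python
import functools

import jax
import jax.numpy as jnp


def prefix_fold_reference(fn, xs):
    """Sequential reference: element k is fn folded left over xs[0], ..., xs[k]."""
    return jnp.stack([functools.reduce(fn, list(xs[: k + 1])) for k in range(xs.shape[0])])


# Small integer entries keep all products exact, so no tolerance is needed.
mats = jax.random.randint(jax.random.PRNGKey(0), (5, 2, 2), 0, 3)
partial_prods = jax.lax.associative_scan(jnp.matmul, mats)

print(jnp.array_equal(partial_prods, prefix_fold_reference(jnp.matmul, mats)))
# The earlier element is the left operand: element 1 is mats[0] @ mats[1].
print(jnp.array_equal(partial_prods[1], mats[0] @ mats[1]))
```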
Example 3: reversed partial sums of an array of numbers:
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
Array([6, 6, 5, 3], dtype=int32)
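The examples above all scan along the default axis=0. The following minimal sketch applies the axis parameter to a 2-D array and combines it with reverse. For addition, the forward result agrees with jnp.cumsum along the same axis.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Scan within each row (axis=1) instead of down the default axis=0.
row_sums = jax.lax.associative_scan(jnp.add, x, axis=1)
print(row_sums)  # expected values: [[0, 1, 3], [3, 7, 12]]

# reverse=True accumulates from the end of the axis toward the start.
rev_row_sums = jax.lax.associative_scan(jnp.add, x, reverse=True, axis=1)
print(rev_row_sums)  # expected values: [[3, 3, 2], [12, 9, 5]]

# For addition, the forward scan matches a cumulative sum along the same axis.
print(jnp.array_equal(row_sums, jnp.cumsum(x, axis=1)))
```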
[BLE1990] Blelloch, Guy E. 1990. “Prefix Sums and Their Applications.” Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.