jax.lax.associative_scan
- jax.lax.associative_scan(fn, elems, reverse=False, axis=0)
Performs a scan with an associative binary operation, in parallel.
For an introduction to associative scans, see [BLE1990].
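To show why associativity lets the work be parallelized, here is a minimal pure-Python sketch of an inclusive scan. It uses a recursive pairwise (odd/even) scheme in the spirit of the prefix-sum algorithms surveyed in [BLE1990]. The helper name prefix_scan is hypothetical. The sketch runs sequentially, and it is not presented as JAX's actual implementation. It matches a left-to-right fold only because fn is assumed associative. For example, position 3 is computed as fn(fn(x0, x1), fn(x2, x3)) rather than fn(fn(fn(x0, x1), x2), x3).

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def prefix_scan(fn: Callable[[T, T], T], xs: List[T]) -> List[T]:
    """Inclusive scan: out[k] combines xs[0], ..., xs[k] left to right with fn.

    Hypothetical helper using a recursive pairwise (odd/even) scheme. Within
    each step the fn calls are independent of one another, which is what a
    parallel implementation would exploit; this sketch runs them in sequence.
    """
    n = len(xs)
    if n < 2:
        return list(xs)
    # Step 1: combine adjacent pairs (xs[0], xs[1]), (xs[2], xs[3]), ...
    pairs = [fn(xs[i], xs[i + 1]) for i in range(0, n - 1, 2)]
    # Step 2: recursively scan the half-length sequence.
    # scanned[j] combines xs[0], ..., xs[2 * j + 1].
    scanned = prefix_scan(fn, pairs)
    out = list(xs)  # out[0] == xs[0] is already correct
    for j, value in enumerate(scanned):
        out[2 * j + 1] = value
    # Step 3: each even position i >= 2 extends the prefix ending at i - 1.
    for i in range(2, n, 2):
        out[i] = fn(scanned[i // 2 - 1], xs[i])
    return out
```

For instance, prefix_scan(lambda p, q: p + q, ["a", "b", "c"]) gives ["a", "ab", "abc"]. String concatenation is associative but not commutative, so the left-to-right argument order passed to fn matters.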
- Parameters
fn (Callable): A Python callable implementing an associative binary operation with signature r = fn(a, b). Function fn must be associative, i.e., it must satisfy the equation fn(a, fn(b, c)) == fn(fn(a, b), c). The inputs and result are (possibly nested Python tree structures of) array(s) matching elems. Each array has a dimension in place of the axis dimension, whose size need not equal num_elems. fn should be applied elementwise over the axis dimension (for example, by using jax.vmap() over the elementwise function). The result r has the same shape (and structure) as the two inputs a and b.
elems: A (possibly nested Python tree structure of) array(s), each with an axis dimension of size num_elems.
reverse (bool): A boolean stating whether the scan should be reversed with respect to the axis dimension.
axis (int): An integer identifying the axis over which the scan should occur.
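The following sketch illustrates the requirements on fn: associativity, pytree inputs and outputs, and elementwise behavior over the axis dimension. It solves the first-order linear recurrence h[t] = a[t] * h[t-1] + b[t] with h[-1] = 0. Each element is treated as an affine map h -> a * h + b stored as a tuple (a, b), and composing affine maps is associative. The helper names compose_affine and linear_recurrence are illustrative assumptions, not part of JAX.

```python
import jax.numpy as jnp
from jax import lax


def compose_affine(left, right):
    """Compose two affine maps h -> a * h + b, applying `left` first.

    Illustrative combiner: each argument is a tuple (a, b) of arrays with
    matching shapes, and every operation is elementwise, so it acts slice by
    slice along the scanned axis without needing jax.vmap.
    """
    a_l, b_l = left
    a_r, b_r = right
    # right(left(h)) = a_r * (a_l * h + b_l) + b_r
    return a_r * a_l, a_r * b_l + b_r


def linear_recurrence(a, b):
    """Illustrative helper: h[t] = a[t] * h[t-1] + b[t], with h[-1] = 0."""
    # elems is a tuple pytree, so the result is a tuple of the same structure;
    # its second entry holds the prefix maps applied to h[-1] = 0.
    _, h = lax.associative_scan(compose_affine, (a, b))
    return h


a = jnp.full(4, 0.5)
b = jnp.ones(4)
h = linear_recurrence(a, b)  # values 1.0, 1.5, 1.75, 1.875
```

Because every operation in compose_affine is elementwise, no jax.vmap is needed here. A combiner built from non-elementwise operations, such as one that reduces over an axis, would need to be vectorized explicitly.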
- Returns
A (possibly nested Python tree structure of) array(s) of the same shape and structure as elems, in which the k'th element along axis is the result of recursively applying fn to combine the first k elements of elems along axis. For example, given elems = [a, b, c, ...] and reverse=False, the result would be [a, fn(a, b), fn(fn(a, b), c), ...].
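The sketch below shows the return value in a few cases. It scans along a non-default axis, shows the effect of reverse=True, and uses a dict-valued pytree whose structure is preserved in the output. The variable names and the sum_and_max combiner are illustrative only. jnp.add and jnp.maximum are both associative and commutative, so reverse introduces no argument-order subtleties here.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Scan along axis=1: each row gets its own running sum.
fwd = lax.associative_scan(jnp.add, x, axis=1)
# -> [[0, 1, 3], [3, 7, 12]]

# reverse=True accumulates from the end of the axis instead.
bwd = lax.associative_scan(jnp.add, x, axis=1, reverse=True)
# -> [[3, 3, 2], [12, 9, 5]]


def sum_and_max(u, v):
    """Illustrative combiner over a dict pytree: running sum and running max."""
    return {"sum": u["sum"] + v["sum"], "max": jnp.maximum(u["max"], v["max"])}


y = jnp.array([3, 1, 4, 1, 5])
stats = lax.associative_scan(sum_and_max, {"sum": y, "max": y})
# stats["sum"] -> [3, 4, 8, 9, 14]; stats["max"] -> [3, 3, 4, 4, 5]
```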
Example 1: partial sums of an array of numbers:

>>> import jax.numpy as jnp
>>> from jax import lax
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
Array([0, 1, 3, 6], dtype=int32)
Example 2: partial products of an array of matrices:

>>> import jax
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)
Example 3: reversed partial sums of an array of numbers:

>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
Array([6, 6, 5, 3], dtype=int32)
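Because associative_scan is a parallel counterpart of a sequential fold, one way to sanity-check it is to compare it against jax.lax.scan on the matrix products from Example 2. The step function below carries the running product and emits it at every position. Its name is illustrative. The two results should agree only up to floating-point rounding, because the parallel version groups the matrix multiplications differently.

```python
import jax
import jax.numpy as jnp
from jax import lax

mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))

# Parallel version: result[k] = mats[0] @ mats[1] @ ... @ mats[k].
parallel = lax.associative_scan(jnp.matmul, mats)


def step(carry, m):
    """Illustrative sequential step: extend the running product by one matrix."""
    prod = carry @ m
    return prod, prod  # (new carry, per-step output)


# Sequential reference, starting from the identity matrix.
_, sequential = lax.scan(step, jnp.eye(2), mats)

# Largest absolute difference; expected to be at rounding-error level.
max_err = jnp.max(jnp.abs(parallel - sequential))
```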
- BLE1990
Blelloch, Guy E. 1990. "Prefix Sums and Their Applications." Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.