jax.lax.psum
jax.lax.psum(x, axis_name, *, axis_index_groups=None)
Compute an all-reduce sum on x over the pmapped axis axis_name.

If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.

Inputs of boolean dtype are converted to integers before the reduction.
- Parameters
  - x – array(s) with a mapped axis named axis_name.
  - axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
  - axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first two and last two replicas). Groups must cover all axis indices exactly once.
- Returns
  Array(s) with the same shape as x representing the result of an all-reduce sum along the axis axis_name.
Examples
For example, with 4 XLA devices available:

>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]
Suppose we want to perform psum among two groups, one with device0 and device1, the other with device2 and device3:

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[1 1 5 5]
An example using 2D-shaped x. Each row is data from one device.
>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
Full psum across all devices:

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[[24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]]
Perform psum among two groups:

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[[ 4  6  8 10]
 [ 4  6  8 10]
 [20 22 24 26]
 [20 22 24 26]]
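The pytree behavior described above can be illustrated without multiple devices: jax.vmap also binds an axis name that collectives like psum can reduce over, so the following single-device sketch (an illustration, not a multi-device pmap run) shows psum applied leaf-by-leaf to a dict of arrays:

```python
import jax
import jax.numpy as jnp

# psum over a pytree sums each leaf independently along the named axis.
# vmap binds the axis name 'i', so this runs on a single device.
tree = {'a': jnp.arange(4.0), 'b': jnp.ones(4)}
out = jax.vmap(lambda t: jax.lax.psum(t, 'i'), axis_name='i')(tree)
print(out['a'])  # [6. 6. 6. 6.]  -- 0+1+2+3, broadcast to every index
print(out['b'])  # [4. 4. 4. 4.]
```

Each leaf is reduced on its own; the output pytree has the same structure and leaf shapes as the input.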
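Similarly, because boolean inputs are converted to integers before the reduction, psum over a boolean array counts the True entries. A minimal single-device sketch, again using vmap to bind the axis name:

```python
import jax
import jax.numpy as jnp

# Boolean inputs are cast to integers before the reduction, so the
# all-reduce sum of a boolean mask is the count of True entries,
# broadcast back to every position along the named axis.
mask = jnp.array([True, False, True, True])
counts = jax.vmap(lambda m: jax.lax.psum(m, 'i'), axis_name='i')(mask)
print(counts)  # [3 3 3 3]
```

Note the result has an integer dtype, not boolean.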