# jax.lax.psum#

jax.lax.psum(x, axis_name, *, axis_index_groups=None)[source]#

Compute an all-reduce sum on `x` over the pmapped axis `axis_name`.

If `x` is a pytree then the result is equivalent to mapping this function to each leaf in the tree.

Inputs of boolean dtype are converted to integers before the reduction.

Parameters
• x – array(s) with a mapped axis named `axis_name`.

• axis_name – hashable Python object used to name a pmapped axis (see the `jax.pmap()` documentation for more details).

• axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first two and last two replicas). Groups must cover all axis indices exactly once.

Returns

Array(s) with the same shape as `x` representing the result of an all-reduce sum along the axis `axis_name`.

Examples

For example, with 4 XLA devices available:

```>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]
```

Suppose we want to perform `psum` among two groups, one with `device0` and `device1`, the other with device2 and device3,

```>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[1 1 5 5]
```

An example using 2D-shaped x. Each row is data from one device.

```>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
[ 4  5  6  7]
[ 8  9 10 11]
[12 13 14 15]]
```

Full `psum` across all devices:

```>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[[24 28 32 36]
[24 28 32 36]
[24 28 32 36]
[24 28 32 36]]
```

Perform `psum` among two groups:

```>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[[ 4  6  8 10]
[ 4  6  8 10]
[20 22 24 26]
[20 22 24 26]]
```