jax.lax.psum

jax.lax.psum(x, axis_name, *, axis_index_groups=None)

Compute an all-reduce sum of x over the pmapped axis axis_name: every participant along the axis receives the sum of the values held by all participants.

If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.
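A minimal sketch of the pytree behaviour, assuming a dict of arrays as the tree. To keep it runnable on a single device, it uses `jax.vmap` with `axis_name="i"` as a stand-in for `jax.pmap`; vmap also binds a named axis that collectives such as `psum` can reduce over. The names `tree` and `summed` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Assumption: vmap with axis_name stands in for pmap so this runs on one device.
tree = {
    "a": jnp.arange(4.0),   # per-instance leaf is a scalar
    "b": jnp.ones((4, 2)),  # per-instance leaf has shape (2,)
}

# psum is applied leaf by leaf; the dict structure is preserved.
summed = jax.vmap(lambda t: jax.lax.psum(t, "i"), axis_name="i")(tree)

print(summed["a"])  # [6. 6. 6. 6.]
print(summed["b"])  # every row is [4. 4.]
```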

Inputs of boolean dtype are converted to integers before the reduction.
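Because booleans are converted to integers first, a `psum` over a boolean flag counts how many participants have the flag set. A small sketch, again assuming `jax.vmap` with a named axis as a single-device stand-in for `jax.pmap`; `flags` and `counts` are illustrative names.

```python
import jax
import jax.numpy as jnp

# One boolean flag per instance along the named axis "i".
flags = jnp.array([True, False, True, True])

# Assumption: vmap stands in for pmap. The result counts True values.
counts = jax.vmap(lambda f: jax.lax.psum(f, "i"), axis_name="i")(flags)

print(counts)        # [3 3 3 3]
print(counts.dtype)  # an integer dtype, not bool
```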

Parameters
  • x – array(s) with a mapped axis named axis_name.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first two and last two replicas). Groups must cover all axis indices exactly once, and all groups must be the same size.
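The grouping rule for axis_index_groups can be illustrated with a host-side NumPy emulation. This is a hypothetical reference helper (`grouped_psum_reference` is not part of JAX) that only models the semantics described above: each replica receives the sum over its own group, groups must cover every index exactly once, and all groups must be the same size.

```python
import numpy as np

def grouped_psum_reference(values, groups):
    """Hypothetical helper: emulates psum with axis_index_groups on the host.

    values[k] is the value held by replica k along the named axis.
    """
    values = np.asarray(values)
    n = values.shape[0]
    flat = sorted(i for g in groups for i in g)
    # Groups must cover every axis index exactly once...
    if flat != list(range(n)):
        raise ValueError("groups must cover all axis indices exactly once")
    # ...and all groups must be the same size.
    if len({len(g) for g in groups}) != 1:
        raise ValueError("all groups must be the same size")
    out = np.empty_like(values)
    for g in groups:
        # Every member of a group receives the sum over that group only.
        out[g] = values[g].sum(axis=0)
    return out

print(grouped_psum_reference(np.arange(4), [[0, 1], [2, 3]]))  # [1 1 5 5]
```

With groups [[0, 2], [1, 3]] the same input would instead give [2 4 2 4], since replicas 0 and 2 share one sum and replicas 1 and 3 share the other.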

Returns

Array(s) with the same shape as x representing the result of an all-reduce sum along the axis axis_name.
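To make the shape statement concrete: when each participant holds a non-scalar value, the reduction is elementwise across participants and each result has the participant's own shape. A sketch assuming `jax.vmap` with a named axis as a single-device stand-in for `jax.pmap`.

```python
import jax
import jax.numpy as jnp

# Four instances along axis "i", each holding a vector of shape (3,).
x = jnp.arange(12.0).reshape(4, 3)

# Assumption: vmap stands in for pmap so this runs on one device.
y = jax.vmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)

# Each instance's result keeps shape (3,); vmap stacks them back to (4, 3).
print(y.shape)  # (4, 3)
print(y[0])     # [18. 22. 26.]
```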

For example, with 4 XLA devices available:

>>> import jax
>>> import numpy as np
>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]
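The example above assumes exactly 4 devices. A variant that sizes the input with `jax.local_device_count()` runs on whatever is available, including a single CPU. As an assumption of this sketch, it counts from 1 rather than 0 so that the normalization step does not compute 0 / 0 when only one device is present.

```python
import jax
import numpy as np

# Size the mapped axis to the number of local devices instead of assuming 4.
n = jax.local_device_count()
x = np.arange(1, n + 1, dtype=np.float32)

total = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)
frac = jax.pmap(lambda v: v / jax.lax.psum(v, "i"), axis_name="i")(x)

print(total)  # every entry equals n * (n + 1) / 2
print(frac)   # entries sum to 1
```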