jax.lax.psum

jax.lax.psum(x, axis_name, *, axis_index_groups=None)

Compute an all-reduce sum on x over the pmapped axis axis_name.

If x is a pytree, the result is equivalent to mapping this function over each leaf of the tree.
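
The pytree behaviour can be sketched as follows. This is a minimal illustration, not part of the official examples: it assumes that jax.vmap with a named axis is an acceptable stand-in for jax.pmap so that it runs on a single device, and the names tree and summed are hypothetical.

```python
import jax
import jax.numpy as jnp

# A pytree (here a dict) whose leaves all carry a leading mapped axis of size 4.
tree = {"a": jnp.arange(4), "b": jnp.ones((4, 2))}

# Assumption: jax.vmap with axis_name stands in for jax.pmap, so one device suffices.
# psum is applied to every leaf of the dict independently.
summed = jax.vmap(lambda t: jax.lax.psum(t, "i"), axis_name="i")(tree)

print(summed["a"])        # every position holds 0 + 1 + 2 + 3
print(summed["b"].shape)  # leaf shapes are preserved
```

The result has the same tree structure as the input, with each leaf reduced separately.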

Inputs of boolean dtype are converted to integers before the reduction.
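
A small sketch of the boolean case, under the same assumption that jax.vmap with a named axis stands in for jax.pmap on a single device (flags and count are hypothetical names):

```python
import jax
import jax.numpy as jnp

# Boolean input: the reduction counts the True entries instead of OR-ing them.
flags = jnp.array([True, False, True, True])

# Assumption: jax.vmap with axis_name stands in for jax.pmap here.
count = jax.vmap(lambda b: jax.lax.psum(b, "i"), axis_name="i")(flags)

print(count)        # number of True values, repeated at every position
print(count.dtype)  # an integer dtype rather than bool
```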

Parameters
  • x – array(s) with a mapped axis named axis_name.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first two and last two replicas). Groups must cover all axis indices exactly once.
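
To make the grouping semantics concrete, the sketch below emulates them in plain jax.numpy along axis 0 of an ordinary array. It is a reference for the expected values only, not how psum is implemented, and grouped_sum_reference is a hypothetical helper rather than a JAX API. With 4 devices, the corresponding call would pass axis_index_groups=[[0, 1], [2, 3]] to jax.lax.psum inside jax.pmap.

```python
import jax.numpy as jnp

def grouped_sum_reference(x, axis_index_groups):
    """Hypothetical helper: emulate grouped all-reduce sums along axis 0 of x."""
    x = jnp.asarray(x)
    out = jnp.zeros_like(x)
    for group in axis_index_groups:
        idx = jnp.asarray(group)
        # Every member of a group receives the sum over that group only.
        out = out.at[idx].set(x[idx].sum(axis=0))
    return out

x = jnp.arange(4)
grouped = grouped_sum_reference(x, [[0, 1], [2, 3]])
print(grouped)  # indices 0 and 1 hold 0 + 1; indices 2 and 3 hold 2 + 3
```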

Returns

Array(s) with the same shape as x representing the result of an all-reduce sum along the axis axis_name.

For example, with 4 XLA devices available:

>>> import jax
>>> import numpy as np
>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]