jax.lax.pmean(x, axis_name, *, axis_index_groups=None)

Compute an all-reduce mean on x over the pmapped axis axis_name.

If x is a pytree, the result is equivalent to mapping this function over each leaf of the tree.
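The pytree behaviour can be checked with a small sketch. It assumes only that JAX is installed; the dict `tree` and its keys `"w"` and `"b"` are illustrative names, not part of the API. The mapped axis is sized to `jax.local_device_count()`, so the snippet should also run on a single-device host (where the mean over one replica is the value itself).

```python
import jax
import jax.numpy as jnp
import numpy as np

# Size of the mapped axis; 1 on a single-device host.
n = jax.local_device_count()

# A pytree (here a dict) whose leaves all carry a leading axis of size n.
tree = {
    "w": jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3),
    "b": jnp.arange(n, dtype=jnp.float32),
}

# pmean applied to the whole pytree at once.
tree_mean = jax.pmap(lambda t: jax.lax.pmean(t, "i"), axis_name="i")(tree)

# pmean applied to each leaf separately, for comparison.
leaf_mean = {
    k: jax.pmap(lambda v: jax.lax.pmean(v, "i"), axis_name="i")(v)
    for k, v in tree.items()
}

for k in tree:
    # Both routes should agree, and each replica should hold the mean
    # of that leaf over the mapped axis.
    np.testing.assert_allclose(tree_mean[k], leaf_mean[k])
    expected = np.broadcast_to(np.asarray(tree[k]).mean(axis=0), tree[k].shape)
    np.testing.assert_allclose(tree_mean[k], expected)
```

Every output leaf keeps the shape of the corresponding input leaf, including the leading mapped axis.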

Parameters:

  • x – array(s) with a mapped axis named axis_name.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first two and last two replicas). Groups must cover all axis indices exactly once, and on TPUs all groups must be the same size.
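The grouping semantics described for axis_index_groups can be illustrated without any devices. The sketch below is a plain NumPy emulation, not JAX's implementation: `grouped_mean` is a hypothetical helper that treats axis 0 as the mapped axis and gives every member of a group that group's mean.

```python
import numpy as np

def grouped_mean(x, axis_index_groups):
    """Hypothetical NumPy emulation of a grouped all-reduce mean over axis 0."""
    x = np.asarray(x, dtype=np.float64)
    # The groups must cover every axis index exactly once.
    flat = sorted(i for group in axis_index_groups for i in group)
    if flat != list(range(x.shape[0])):
        raise ValueError("groups must cover all axis indices exactly once")
    out = np.empty_like(x)
    for group in axis_index_groups:
        # Every member of a group receives that group's mean.
        out[group] = x[group].mean(axis=0)
    return out

x = np.arange(4)
result = grouped_mean(x, [[0, 1], [2, 3]])
print(result)  # [0.5 0.5 2.5 2.5]
```

With an axis of size 4 and groups [[0, 1], [2, 3]], the first two entries share the mean of 0 and 1, and the last two share the mean of 2 and 3.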


Returns:

Array(s) with the same shape as x, representing the result of an all-reduce mean along the axis axis_name.

For example, with 4 XLA devices available:

>>> import jax
>>> import numpy as np
>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x)
>>> print(y)
[1.5 1.5 1.5 1.5]
>>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.        0.6666667 1.3333334 2.       ]