jax.lax.pmax(x, axis_name, *, axis_index_groups=None)
Compute an all-reduce max on x over the pmapped axis axis_name.
If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.
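The sketch below illustrates the all-reduce semantics and the pytree behaviour. The page describes pmax in terms of jax.pmap. This sketch instead uses jax.vmap with its axis_name argument, because vmap also binds a named axis that collectives such as pmax can reduce over, and that lets the example run on a single CPU device. The input values and the axis name "i" are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

# Each of the 4 entries along the named axis "i" holds one row of x.
x = jnp.array([[1.0, 5.0, 2.0],
               [4.0, 0.0, 7.0],
               [3.0, 9.0, 1.0],
               [2.0, 6.0, 8.0]])

def replica_max(row):
    # Elementwise max across every entry of the named axis "i".
    return jax.lax.pmax(row, axis_name="i")

# vmap maps over the leading axis and names it "i". pmax reduces over it,
# and every entry receives the same reduced value.
out = jax.vmap(replica_max, axis_name="i")(x)
print(out)  # every row is [4. 9. 8.], the column-wise maximum

# Pytree input: pmax is applied to each leaf independently.
tree = {"a": x, "b": jnp.arange(4.0)}
tree_out = jax.vmap(lambda t: jax.lax.pmax(t, "i"), axis_name="i")(tree)
print(tree_out["b"])  # [3. 3. 3. 3.]
```

Note that the output keeps the mapped axis. The reduction does not remove it; it replaces each entry with the shared maximum.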
Parameters:
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first two and last two replicas). Groups must cover all axis indices exactly once, and on TPUs all groups must be the same size.
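Running axis_index_groups for real requires several devices, so the following is a plain-Python reference model of the grouping semantics rather than a JAX program. It is a sketch under stated assumptions: grouped_pmax_reference is a hypothetical helper, not part of JAX. It treats each replica's value as a scalar, whereas the real operation is elementwise on arrays, and it does not model the TPU equal-group-size restriction.

```python
from typing import List, Sequence

def grouped_pmax_reference(values: Sequence[float],
                           groups: List[List[int]]) -> List[float]:
    """Hypothetical model of pmax with axis_index_groups.

    values[k] is the (scalar) value held by replica k along the named axis.
    Each replica receives the max over the replicas in its own group.
    """
    n = len(values)
    # Mirror the documented rule: groups cover all axis indices exactly once.
    flat = sorted(i for group in groups for i in group)
    if flat != list(range(n)):
        raise ValueError("groups must cover all axis indices exactly once")
    result = [0.0] * n
    for group in groups:
        group_max = max(values[i] for i in group)
        for i in group:
            result[i] = group_max
    return result

# The [[0, 1], [2, 3]] example from the parameter description, axis size 4.
print(grouped_pmax_reference([1.0, 5.0, 2.0, 7.0], [[0, 1], [2, 3]]))
# -> [5.0, 5.0, 7.0, 7.0]
```

With a single group containing every index, the model reduces to the ungrouped all-reduce max.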
Returns:
Array(s) with the same shape as x representing the result of an all-reduce max along the axis axis_name.
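The following is a minimal sketch of the same reduction under jax.pmap, showing that the result has the same shape as the input. It sizes the mapped axis to jax.local_device_count() so that it also runs on a single-device machine, where the reduction is trivial. The axis name "devices" is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
# One row per device; row k is filled with the value k.
x = jnp.arange(n, dtype=jnp.float32)[:, None] * jnp.ones((1, 3), dtype=jnp.float32)

# pmap maps the leading axis over devices and names it "devices".
out = jax.pmap(lambda v: jax.lax.pmax(v, "devices"), axis_name="devices")(x)

# Same shape as x; every device's slice holds the global max, n - 1.
print(out.shape, out)
```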