pmin(x, axis_name, *, axis_index_groups=None)
Compute an all-reduce min on x over the pmapped axis axis_name.
If x is a pytree then the result is equivalent to mapping this function to each leaf in the tree.
x – array(s) with a mapped axis named axis_name.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
axis_index_groups – optional list of lists containing axis indices (e.g. for an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first two and last two replicas). Groups must cover all axis indices exactly once, and all groups must be the same size.
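The grouping described for axis_index_groups can be illustrated with a minimal sketch. It assumes a CPU backend on which four host devices can be forced through the XLA_FLAGS environment variable (set before JAX is imported); the input values and the axis name "i" are chosen only for illustration.

```python
import os

# Assumption: force four host (CPU) devices so the mapped axis can have size 4.
# This must happen before JAX initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import numpy as np
from jax import lax

cpu_devices = jax.devices("cpu")[:4]

# One scalar per replica along the mapped axis (illustrative values).
x = np.array([3, 1, 4, 2])


def grouped_min(v):
    # Replicas 0 and 1 reduce together; replicas 2 and 3 reduce together.
    return lax.pmin(v, axis_name="i", axis_index_groups=[[0, 1], [2, 3]])


y = jax.pmap(grouped_min, axis_name="i", devices=cpu_devices)(x)
print(y)  # [1 1 2 2]
```

Each replica receives the minimum of its own group only, so the first two entries hold min(3, 1) and the last two hold min(4, 2).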
Array(s) with the same shape as x representing the result of an all-reduce min along the axis axis_name.
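As a rough usage sketch, the following applies pmin inside jax.pmap to a single array and to a small pytree. The axis name "i", the helper all_min, and the input values are illustrative assumptions; the mapped axis is sized to the number of local devices so the snippet also runs on a single-device machine, where the reduction is trivial.

```python
import jax
import jax.numpy as jnp
from jax import lax

# pmap needs one device per index of the mapped axis.
n = jax.local_device_count()

# One row per device: row i is [i, -i].
idx = jnp.arange(n, dtype=jnp.float32)
x = jnp.stack([idx, -idx], axis=1)  # shape (n, 2)


def all_min(v):
    # Inside pmap, v is the per-device slice (or a pytree of such slices).
    return lax.pmin(v, axis_name="i")


# Every row of the result is the elementwise minimum across all rows of x.
result = jax.pmap(all_min, axis_name="i")(x)

# With a pytree input, the reduction is applied to each leaf independently.
tree = {"a": x, "b": 2.0 * x}
tree_result = jax.pmap(all_min, axis_name="i")(tree)

print(result.shape)         # same shape as x
print(sorted(tree_result))  # same keys as the input pytree
```

The output keeps the mapped axis, so every device ends up holding an identical copy of the reduced values.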