jax.experimental.shard_map.shard_map
- jax.experimental.shard_map.shard_map(f, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))
Map a function over shards of data.
Note
shard_map is an experimental API, and still subject to change. For an introduction to sharded data, refer to Introduction to parallel programming. For a more in-depth look at using shard_map, refer to SPMD multi-device parallelism with shard_map.
- Parameters:
f (Callable) – callable to be mapped. Each application of f, or “instance” of f, takes as input a shard of the mapped-over arguments and produces a shard of the output.
mesh (Mesh | AbstractMesh) – a jax.sharding.Mesh representing the array of devices over which to shard the data and on which to execute instances of f. The axis names of the Mesh can be used in collective communication operations in f. This is typically created by a utility function like jax.experimental.mesh_utils.create_device_mesh().
in_specs (Specs) – a pytree with PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to NamedSharding, each PartitionSpec represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of mesh. In each PartitionSpec, mentioning a mesh axis name at a position expresses sharding the corresponding argument array axis along that mesh axis; not mentioning an axis name expresses replication. If an argument, or argument subtree, has a corresponding spec of None, that argument is not sharded.
out_specs (Specs) – a pytree with PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the output of f. Each PartitionSpec represents how the corresponding output shards should be concatenated. In each PartitionSpec, mentioning a mesh axis name at a position expresses concatenation of that mesh axis’s shards along the corresponding positional axis. Not mentioning a mesh axis name expresses a promise that the output values are equal along that mesh axis, and that, rather than concatenating, only a single value should be produced.
check_rep (bool) – If True (default), enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in out_specs are consistent with how the outputs of f are replicated. Must be set to False if using a Pallas kernel in f.
auto (frozenset[AxisName]) – (experimental) an optional set of axis names from mesh over which we do not shard the data or map the function, but rather we allow the compiler to control sharding. These names cannot be used in in_specs, out_specs, or in communication collectives in f.
- Returns:
A callable that applies the input function f across data sharded according to the mesh and in_specs.
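The following is a minimal sketch, not taken from the JAX documentation, of how in_specs and out_specs shape the inputs and outputs of the returned callable. It builds a one-dimensional mesh over whatever devices are available and names its axis "x" (an arbitrary choice). The functions double and total, and the array sizes, are hypothetical. The fallback import is an assumption for releases that expose shard_map at a different path.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax.experimental.shard_map import shard_map
except ImportError:
    # Assumption: some newer releases expose the function as jax.shard_map.
    from jax import shard_map

# One-dimensional mesh over all available devices; "x" is an illustrative name.
devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, axis_names=("x",))

# Global input whose length is divisible by the number of devices.
x = jnp.arange(4 * n, dtype=jnp.float32)

def double(block):
    # Each instance receives one shard of x, here of shape (4,).
    return 2.0 * block

# Naming "x" in out_specs concatenates the per-device results along axis 0.
doubled = jax.jit(
    shard_map(double, mesh=mesh, in_specs=(P("x"),), out_specs=P("x"))
)(x)

def total(block):
    # psum over the mesh axis makes the result equal on every device.
    return jax.lax.psum(jnp.sum(block), "x")

# Not naming "x" in out_specs promises that the output is replicated,
# so a single value is returned instead of a concatenation.
summed = jax.jit(
    shard_map(total, mesh=mesh, in_specs=(P("x"),), out_specs=P())
)(x)
```

Here doubled has the same global shape as x, while summed is a scalar holding the sum of all elements of x.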
Examples
For examples, refer to Introduction to parallel programming or SPMD multi-device parallelism with shard_map.
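As a further hedged sketch, the snippet below passes a tuple of PartitionSpec instances as in_specs, one per positional argument, and uses a collective over the mesh axis inside the mapped function. The axis name "k", the helper matmul_block, and the list block_shapes are hypothetical names chosen for illustration. The fallback import is an assumption for releases that expose shard_map at a different path.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax.experimental.shard_map import shard_map
except ImportError:
    # Assumption: some newer releases expose the function as jax.shard_map.
    from jax import shard_map

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, axis_names=("k",))

# The contraction dimension (4 * n) is split evenly across the mesh axis "k".
a = jnp.arange(8 * 4 * n, dtype=jnp.float32).reshape(8, 4 * n)
b = jnp.ones((4 * n, 3), dtype=jnp.float32)

block_shapes = []  # filled in at trace time with the per-instance shapes

def matmul_block(a_block, b_block):
    block_shapes.append((a_block.shape, b_block.shape))
    # Each instance multiplies its slice of the contraction dimension...
    partial_product = a_block @ b_block
    # ...and the partial products are summed across the mesh axis.
    return jax.lax.psum(partial_product, "k")

sharded_matmul = shard_map(
    matmul_block,
    mesh=mesh,
    in_specs=(P(None, "k"), P("k", None)),  # one spec per argument
    out_specs=P(None, None),  # output is replicated along "k"
)
c = jax.jit(sharded_matmul)(a, b)
```

Each instance sees blocks of shape (8, 4) and (4, 3) regardless of the device count, and c matches the unsharded product a @ b up to floating-point tolerance.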