jax.experimental.shard_map.shard_map#

jax.experimental.shard_map.shard_map(f, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[source]#

Map a function over shards of data.

Note

shard_map is an experimental API, and still subject to change. For an introduction to sharded data, refer to Introduction to parallel programming. For a more in-depth look at using shard_map, refer to SPMD multi-device parallelism with shard_map.

Parameters:
  • f (Callable) – callable to be mapped. Each application of f, or “instance” of f, takes as input a shard of the mapped-over arguments and produces a shard of the output.

  • mesh (Mesh | AbstractMesh) – a jax.sharding.Mesh representing the array of devices over which to shard the data and on which to execute instances of f. The names of the Mesh can be used in collective communication operations in f. This is typically created by a utility function like jax.experimental.mesh_utils.create_device_mesh().

  • in_specs (Specs) – a pytree with PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to NamedSharding, each PartitionSpec represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of mesh. In each PartitionSpec, mentioning a mesh axis name at a position expresses sharding the corresponding argument array axis along that positional axis; not mentioning an axis name expresses replication. If an argument, or argument subtree, has a corresponding spec of None, that argument is not sharded.

  • out_specs (Specs) – a pytree with PartitionSpec instances as leaves, with a tree structure that is a tree prefix of the output of f. Each PartitionSpec represents how the corresponding output shards should be concatenated. In each PartitionSpec, mentioning a mesh axis name at a position expresses concatenation of that mesh axis’s shards along the corresponding positional axis. Not mentioning a mesh axis name expresses a promise that the output values are equal along that mesh axis, so that only a single value is produced rather than a concatenation. A sketch following this parameter list shows in_specs and out_specs in use.

  • check_rep (bool) – If True (default), enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in out_specs are consistent with how the outputs of f are replicated. Must be set to False if f uses a Pallas kernel.

  • auto (frozenset[AxisName]) – (experimental) an optional set of axis names from mesh over which we do not shard the data or map the function, but rather allow the compiler to control sharding. These names may not appear in in_specs, out_specs, or in collective communication operations in f.
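
For orientation, here is a minimal sketch (not taken from the linked guides) of how mesh, in_specs, and out_specs fit together: a matrix multiply whose left operand is sharded by rows over a hypothetical 1D mesh axis named 'i'. The function name matmul_rows, the array shapes, and the assumption of four available devices are illustrative.

    import functools
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental import mesh_utils
    from jax.experimental.shard_map import shard_map

    # Assumption: four devices are available to form a 1D mesh with axis name 'i'.
    mesh = Mesh(mesh_utils.create_device_mesh((4,)), axis_names=('i',))

    a = jnp.arange(8 * 16.).reshape(8, 16)
    b = jnp.arange(16 * 4.).reshape(16, 4)

    @functools.partial(shard_map, mesh=mesh,
                       in_specs=(P('i', None), P(None, None)),  # shard a's rows; replicate b
                       out_specs=P('i', None))                  # concatenate row blocks of the result
    def matmul_rows(a_block, b_full):
      # Each instance sees a (2, 16) row block of a and a full copy of b.
      return a_block @ b_full

    c = matmul_rows(a, b)  # equal to a @ b, computed shard by shard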

Returns:

A callable that applies the input function f across data sharded according to the mesh and in_specs.

Examples

For examples, refer to Introduction to parallel programming or SPMD multi-device parallelism with shard_map.
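
As a further hedged sketch (again assuming four devices, and again not taken from those guides), the following shows a collective inside f: each instance sums its local shard, and jax.lax.psum combines the partial sums across mesh axis 'i'. Because every instance then holds the same value, out_specs=P() omits 'i', expressing the replication promise described under out_specs above.

    import functools
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental import mesh_utils
    from jax.experimental.shard_map import shard_map

    # Assumption: four devices form a 1D mesh with axis name 'i'.
    mesh = Mesh(mesh_utils.create_device_mesh((4,)), axis_names=('i',))

    @functools.partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
    def total(x_block):
      # Sum the local shard, then all-reduce the partial sums over mesh axis 'i'.
      return jax.lax.psum(jnp.sum(x_block), 'i')

    x = jnp.arange(16.)
    print(total(x))  # 120.0, replicated on every device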