- jax.experimental.pjit.pjit(fun, in_axis_resources=<jax.interpreters.pxla._UnspecifiedValue object>, out_axis_resources=<jax.interpreters.pxla._UnspecifiedValue object>, static_argnums=(), donate_argnums=())
Makes fun compiled and automatically partitioned across multiple devices.
The returned function has semantics equivalent to those of fun, but is compiled to an XLA computation that runs across multiple devices (e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted version of fun would not fit in a single device's memory, or to speed up fun by running each operation in parallel across multiple devices.
The partitioning over devices happens automatically based on the propagation of the input partitioning specified in in_axis_resources and the output partitioning specified in out_axis_resources. The resources specified in those two arguments must refer to mesh axes, as defined by the jax.experimental.maps.Mesh() context manager. Note that the mesh definition at pjit application time is ignored, and the returned function will use the mesh definition available at each call site.
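This late binding of the mesh can be pictured with a minimal pure-Python sketch (an illustration only, not JAX's implementation — `FakeMesh` and `fake_pjit` are hypothetical names): the wrapper captures no mesh when it is created and instead reads whatever context is active at each call.

```python
import contextvars

# Hypothetical stand-in for the active Mesh context (illustration only).
_active_mesh = contextvars.ContextVar("mesh", default=None)

class FakeMesh:
    def __init__(self, name):
        self.name = name
    def __enter__(self):
        self._token = _active_mesh.set(self)
        return self
    def __exit__(self, *exc):
        _active_mesh.reset(self._token)

def fake_pjit(fun):
    # The mesh is looked up at *call* time, not at decoration time.
    def wrapped(*args):
        mesh = _active_mesh.get()
        if mesh is None:
            raise RuntimeError("no mesh available at the call site")
        return fun(*args), mesh.name
    return wrapped

@fake_pjit
def double(x):
    return 2 * x

with FakeMesh("eight_gpus"):
    print(double(3))  # uses whichever mesh is active here
```

Calling `double` outside any mesh context fails, mirroring how a pjit'd function depends on the mesh available where it is invoked rather than where it was defined.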
Inputs to a pjit’d function will be automatically partitioned across devices if they’re not already correctly partitioned based on in_axis_resources. In some scenarios, ensuring that the inputs are already correctly pre-partitioned can increase performance. For example, if passing the output of one pjit’d function to another pjit’d function (or the same pjit’d function in a loop), make sure the relevant out_axis_resources match the corresponding in_axis_resources.
Multi-process platforms: On multi-process platforms such as TPU pods, pjit can be used to run computations across all available devices across processes. To achieve this, pjit is designed to be used in SPMD Python programs, where every process is running the same Python code such that all processes run the same pjit’d function in the same order.
When running in this configuration, the mesh should contain devices across all processes. However, any input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, and outputs will be similarly sized according to the local mesh.
fun will still be executed across all devices in the mesh, including those from other processes, and will be given a global view of the data spread across multiple processes as a single array. However, outside of pjit every process only “sees” its local piece of the input and output, corresponding to its local sub-mesh.
This means that each process’s participating local devices must form a _contiguous_ local sub-mesh within the full global mesh. A contiguous sub-mesh is one where all of its devices are adjacent within the global mesh, and form a rectangular prism.
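The contiguity requirement can be checked with a short pure-Python sketch (illustration only; `is_contiguous_submesh` is a hypothetical helper, not part of the JAX API): a set of device coordinates is a contiguous rectangular sub-mesh exactly when it fills its own bounding box.

```python
import itertools

def is_contiguous_submesh(global_shape, local_coords):
    """Check whether a set of mesh coordinates forms a contiguous
    rectangular block ("rectangular prism") of the global mesh.
    Illustration only; not part of the JAX API."""
    local_coords = set(local_coords)
    ndim = len(global_shape)
    # Bounding box of the local coordinates along each mesh axis.
    lo = [min(c[d] for c in local_coords) for d in range(ndim)]
    hi = [max(c[d] for c in local_coords) for d in range(ndim)]
    # Contiguous iff the coordinates fill their bounding box exactly.
    box = set(itertools.product(*[range(l, h + 1) for l, h in zip(lo, hi)]))
    return box == local_coords

# A 2x4 global mesh; a process owning the left 2x2 block is contiguous.
print(is_contiguous_submesh((2, 4), [(0, 0), (0, 1), (1, 0), (1, 1)]))  # True
# Two opposite corners alone do not form a rectangular block.
print(is_contiguous_submesh((2, 4), [(0, 0), (1, 3)]))  # False
```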
The SPMD model also requires that the same multi-process pjit’d functions must be run in the same order on all processes, but they can be interspersed with arbitrary operations running in a single process.
- Parameters
fun (Callable) – Function to be compiled. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
in_axis_resources – Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.
- The valid resource assignment specifications are:
- None, in which case the value will be replicated on all devices
- PartitionSpec, a tuple of length at most equal to the rank of the partitioned value. Each element can be None, a mesh axis, or a tuple of mesh axes, and specifies the set of resources assigned to partition the value’s dimension matching its position in the spec.
The size of every dimension has to be a multiple of the total number of resources assigned to it.
out_axis_resources – Like in_axis_resources, but specifies resource assignment for function outputs.
static_argnums – An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), so the corresponding argument values can be any Python object.
Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static. If static_argnums is not provided, no arguments are treated as static.
donate_argnums – Specify which argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.
For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
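The divisibility rule stated for in_axis_resources above is simple arithmetic, sketched here in pure Python (illustration only; `shard_dim` is a hypothetical helper, not JAX internals): the per-shard size along a dimension is the dimension size divided by the product of the sizes of the mesh axes assigned to it.

```python
from math import prod

def shard_dim(dim_size, axis_sizes):
    """Shard one array dimension over the mesh axes assigned to it.
    Illustrates the divisibility rule only; not JAX internals."""
    total = prod(axis_sizes)  # total resources assigned to this dimension
    if dim_size % total != 0:
        raise ValueError(f"dimension of size {dim_size} is not divisible "
                         f"by {total} assigned resources")
    return dim_size // total  # per-shard size along this dimension

# A size-8 dimension over a (2, 2) pair of mesh axes -> shards of size 2.
print(shard_dim(8, (2, 2)))
```

A size-6 dimension over 4 devices would fail the check, just as pjit rejects a PartitionSpec whose assigned resources do not evenly divide the dimension.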
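How static arguments enter the compilation cache key, as described for static_argnums above, can be sketched with a toy pure-Python cache (illustration only; `fake_jit` is hypothetical and "compilation" here is just counting cache misses): static values are keyed by their actual (hashable) values, while the remaining arguments are keyed only by a stand-in for their abstract type.

```python
def fake_jit(fun, static_argnums=()):
    """Toy sketch of a compilation cache keyed by static argument values.
    Not JAX's implementation."""
    cache = {}
    compiles = []  # record of cache misses ("compilations")

    def wrapped(*args):
        # Static args must be hashable: they go into the key by value.
        static = tuple(args[i] for i in static_argnums)
        # Dynamic args are keyed only by an abstract signature
        # (here just the Python type, standing in for shape/dtype).
        dynamic_sig = tuple(type(a) for i, a in enumerate(args)
                            if i not in static_argnums)
        key = (static, dynamic_sig)
        if key not in cache:
            compiles.append(key)   # cache miss: "compile"
            cache[key] = fun       # a real system stores compiled code
        return cache[key](*args)

    wrapped.num_compiles = lambda: len(compiles)
    return wrapped

f = fake_jit(lambda x, n: x * n, static_argnums=(1,))
f(1.0, 3); f(2.0, 3)     # same static value: one "compilation"
f(1.0, 4)                # new static value: triggers a "recompilation"
print(f.num_compiles())  # 2
```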
- Returns
A wrapped version of fun, set up for just-in-time compilation and automatically partitioned by the mesh available at each call site.
For example, a convolution operator can be automatically partitioned over an arbitrary set of devices by a single pjit application:
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental.pjit import PartitionSpec, pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...          in_axis_resources=None, out_axis_resources=PartitionSpec('devices'))
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   print(f(x))
[ 0.5  2.   4.   6.   8.  10.  12.  10. ]