jax.experimental.pjit module

API

jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)

Makes fun compiled and automatically partitioned across multiple devices.

NOTE: This function is now equivalent to jax.jit(); please use that instead. The returned function has semantics equivalent to those of fun, but is compiled to an XLA computation that runs across multiple devices (e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted version of fun would not fit in a single device’s memory, or to speed up fun by running each operation in parallel across multiple devices.

The partitioning over devices happens automatically based on the propagation of the input partitioning specified in in_shardings and the output partitioning specified in out_shardings. The resources specified in those two arguments must refer to mesh axes, as defined by the jax.sharding.Mesh() context manager. Note that the mesh definition at pjit() application time is ignored, and the returned function will use the mesh definition available at each call site.

Inputs to a pjit()’d function will be automatically partitioned across devices if they’re not already correctly partitioned based on in_shardings. In some scenarios, ensuring that the inputs are already correctly pre-partitioned can increase performance. For example, if passing the output of one pjit()’d function to another pjit()’d function (or the same pjit()’d function in a loop), make sure the relevant out_shardings match the corresponding in_shardings.
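
For instance, the following minimal sketch (assuming the number of available devices divides 8) keeps the intermediate result sharded so that the second call does not need to repartition its input:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> spec = PartitionSpec('devices')
>>> f = pjit(lambda x: x * 2., out_shardings=spec)   # produce output sharded over 'devices'
>>> g = pjit(lambda x: x + 1., in_shardings=spec)    # expects input with the same sharding
>>>
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   y = f(jnp.arange(8.))   # y is already laid out according to `spec`
...   z = g(y)                # no repartitioning needed at this boundary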

Note

Multi-process platforms: On multi-process platforms such as TPU pods, pjit() can be used to run computations across all available devices across processes. To achieve this, pjit() is designed to be used in SPMD Python programs, where every process is running the same Python code such that all processes run the same pjit()’d function in the same order.

When running in this configuration, the mesh should contain devices across all processes. However, any input argument dimensions partitioned over multi-process mesh axes should be of size equal to the corresponding local mesh axis size, and outputs will be similarly sized according to the local mesh. fun will still be executed across all devices in the mesh, including those from other processes, and will be given a global view of the data spread across multiple processes as a single array. However, outside of pjit() every process only “sees” its local piece of the input and output, corresponding to its local sub-mesh.

This means that each process’s participating local devices must form a contiguous local sub-mesh within the full global mesh. A contiguous sub-mesh is one where all of its devices are adjacent within the global mesh, and form a rectangular prism.

The SPMD model also requires that the same multi-process pjit()’d functions must be run in the same order on all processes, but they can be interspersed with arbitrary operations running in a single process.
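
A minimal SPMD sketch of this setup is shown below; jax.distributed.initialize() takes environment-specific arguments (they can often be omitted on Cloud TPU), and some_global_input is a hypothetical globally sharded jax.Array standing in for real data:

>>> # Every process runs this exact same script.
>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> jax.distributed.initialize()   # environment-specific arguments; see the multi-process docs
>>>
>>> # The mesh spans the devices of *all* processes, not just the local ones.
>>> mesh = Mesh(np.array(jax.devices()), ('devices',))
>>> f = pjit(lambda x: 2 * x, in_shardings=PartitionSpec('devices'))
>>>
>>> with mesh:
...   # All processes must call f in the same order.
...   out = f(some_global_input)   # hypothetical globally sharded jax.Array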

Parameters:
  • fun (Callable) – Function to be compiled. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.

  • in_shardings –

    Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.

    The in_shardings argument is optional. JAX will infer the shardings from the input jax.Array values, and defaults to replicating the input if the sharding cannot be inferred.

    The valid resource assignment specifications are:

    • XLACompatibleSharding, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.

    • None is a special case whose semantics are:
      • If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.

      • If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.

    • For backwards compatibility, in_shardings still supports ingesting PartitionSpec. This option can only be used with the mesh context manager.

      • PartitionSpec, a tuple of length at most equal to the rank of the partitioned value. Each element can be a None, a mesh axis or a tuple of mesh axes, and specifies the set of resources assigned to partition the value’s dimension matching its position in the spec.

    The size of every dimension has to be a multiple of the total number of resources assigned to it. (A short sketch of these specifications follows this parameter list.)

  • out_shardings – Like in_shardings, but specifies resource assignment for function outputs. The out_shardings argument is optional. If not specified, jax.jit() will use GSPMD’s sharding propagation to determine how to shard the outputs.

  • static_argnums (int | Sequence[int] | None) –

    An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.

    Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.

    If static_argnums is not provided, no arguments are treated as static. (See the sketch after the Returns entry below.)

  • static_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • donate_argnums (int | Sequence[int] | None) –

    Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated. (See the sketch after the Returns entry below.)

    If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.

    For more details on buffer donation see the FAQ.

  • donate_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • keep_unused (bool) – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.

  • device (xc.Device | None) – This argument is deprecated. Please put your arguments on the device you want before passing them to jit. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

  • backend (str | None) – This argument is deprecated. Please put your arguments on the backend you want before passing them to jit. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • inline (bool) –

  • abstracted_axes (Any | None) –

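The sketch below illustrates the resource-assignment options described above. It assumes eight available devices (for the 2x4 mesh shape): an XLACompatibleSharding such as jax.sharding.NamedSharding can be passed without a mesh context manager, a bare PartitionSpec needs one, and None together with an active mesh context means the value is replicated.

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))   # assumes 8 devices
>>> a = jnp.zeros((8, 16))
>>>
>>> # A Sharding object: no mesh context manager required.
>>> s = NamedSharding(mesh, PartitionSpec('x', 'y'))
>>> f = pjit(lambda a: a + 1, in_shardings=s, out_shardings=s)
>>> out = f(a)   # rows split over 'x', columns over 'y'
>>>
>>> # A bare PartitionSpec: only valid with the mesh context manager active.
>>> g = pjit(lambda a: a + 1,
...          in_shardings=PartitionSpec('x', None),   # shard dim 0, leave dim 1 whole
...          out_shardings=None)                      # None + mesh context: replicated output
>>> with mesh:
...   out = g(a)
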
Return type:

JitWrapped

Returns:

A wrapped version of fun, set up for just-in-time compilation and automatically partitioned by the mesh available at each call site.
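
To illustrate static_argnums and donate_argnums together, here is a hedged sketch; step and its arguments are made up for illustration, and whether XLA actually reuses the donated buffer depends on the backend:

>>> import jax.numpy as jnp
>>> from jax.experimental.pjit import pjit
>>>
>>> def step(params, lr):
...   return params - lr * params   # toy update; output matches params in shape and dtype
...
>>> # lr is static: it is hashable and baked into the compiled computation,
>>> # so calling with a different lr triggers a recompilation.
>>> # params is donated: its buffer may be recycled for the output.
>>> f = pjit(step, static_argnums=(1,), donate_argnums=(0,))
>>>
>>> params = jnp.ones((4,), dtype=jnp.float32)
>>> new_params = f(params, 0.1)
>>> # Do not reuse `params` after this call; JAX raises an error if you do.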

For example, a convolution operator can be automatically partitioned over an arbitrary set of devices by a single pjit() application:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...         in_shardings=None, out_shardings=PartitionSpec('devices'))
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   print(f(x))  
[ 0.5  2.   4.   6.   8.  10.  12.  10. ]
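
Because pjit() is now equivalent to jax.jit(), the same convolution can be written without the experimental import. The following hedged rewrite assumes a JAX version in which jax.jit() accepts out_shardings, and reuses x, jnp, np, and jax from the example above:

>>> from jax.sharding import NamedSharding
>>>
>>> mesh = Mesh(np.array(jax.devices()), ('devices',))
>>> out_s = NamedSharding(mesh, PartitionSpec('devices'))
>>> f = jax.jit(lambda x: jnp.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...             out_shardings=out_s)   # explicit sharding, so no mesh context manager needed
>>> print(f(x))
[ 0.5  2.   4.   6.   8.  10.  12.  10. ]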