jax.jit
- jax.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)
Sets up fun for just-in-time compilation with XLA.

- Parameters:
fun (Callable) –
Function to be jitted.
fun should be a pure function, as side effects may only be executed once (while the function is being traced).

The arguments and return value of fun should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.

JAX keeps a weak reference to fun for use as a compilation cache key, so the object fun must be weakly-referenceable. Most Callable objects will already satisfy this requirement.

in_shardings –
Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.

The in_shardings argument is optional. JAX infers the shardings from the input jax.Array objects and defaults to replicating an input if its sharding cannot be inferred.

The valid resource assignment specifications are:
- XLACompatibleSharding, which decides how the value will be partitioned. With this, using a mesh context manager is not required.
- None, which gives JAX the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, JAX relies on the XLA GSPMD partitioner to determine the output shardings.

The size of every dimension has to be a multiple of the total number of resources assigned to it. This is similar to pjit’s in_shardings.
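A minimal sketch of in_shardings, assuming a JAX version that provides jax.sharding.Mesh, NamedSharding and PartitionSpec (NamedSharding is one concrete sharding class that jax.jit accepts here). The function scale and the mesh axis name "x" are illustrative only. The tuple passed to in_shardings mirrors the positional arguments (x, w): x is split by rows across the mesh and w is replicated; the leading dimension of x is chosen as a multiple of the device count, as required above.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every available device (a single CPU device also works).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
rows = NamedSharding(mesh, P("x"))     # split dimension 0 across mesh axis "x"
replicated = NamedSharding(mesh, P())  # a full copy on every device

def scale(x, w):  # hypothetical example function
    return x * w

# One sharding per positional argument: the tuple mirrors the argument pytree (x, w).
scale_jit = jax.jit(scale, in_shardings=(rows, replicated))

# Dimension 0 must be a multiple of the number of devices along "x".
n = jax.device_count() * 4
x = np.arange(n * 2, dtype=np.float32).reshape(n, 2)
w = np.float32(3.0)
y = scale_jit(x, w)
print(y.sharding)  # output sharding left to GSPMD propagation (no out_shardings given)
```

Because out_shardings is omitted, the output sharding is whatever sharding propagation decides.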
out_shardings –
Like in_shardings, but specifies resource assignment for function outputs. This is similar to pjit’s out_shardings.

The out_shardings argument is optional. If not specified, jax.jit() will use GSPMD’s sharding propagation to figure out what the sharding of the output(s) should be.

static_argnums (int | Sequence[int] | None) –
An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.
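A minimal sketch of static_argnums, using a hypothetical function power: the loop bound n must be a concrete Python int, so it is marked static, and a counter (a deliberate trace-time side effect) shows that each new value of n triggers a fresh trace while a repeated value hits the cache.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def power(x, n):  # hypothetical example function
    global trace_count
    trace_count += 1       # Python side effect: runs only while JAX traces power
    for _ in range(n):     # range() needs a concrete Python int, so n must be static
        x = x * x
    return x

power_jit = jax.jit(power, static_argnums=1)

x = jnp.float32(2.0)
a = power_jit(x, 2)   # traces and compiles for n=2: 2 -> 4 -> 16
b = power_jit(x, 2)   # same static value: served from the cache, no retrace
c = power_jit(x, 3)   # new static value: traces and compiles again (256)
print(a, b, c, trace_count)
```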
Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.

If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.

static_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

donate_argnums (int | Sequence[int] | None) –
Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example by recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated.
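A minimal sketch of donate_argnums with a hypothetical function add_one. Whether the donated buffer is actually reused depends on the backend: where donation is honored (for example GPU/TPU) the input array is deleted after the call, while a backend that cannot use it may only emit a warning. The sketch therefore prints x.is_deleted() instead of assuming a result.

```python
import jax
import jax.numpy as jnp

def add_one(x):  # hypothetical example function
    return x + 1

# Donate the buffer of positional argument 0 so XLA may reuse it for the output.
add_one_donating = jax.jit(add_one, donate_argnums=0)

x = jnp.zeros((1024,), dtype=jnp.float32)
y = add_one_donating(x)

# Treat x as consumed from here on. If the backend honored the donation,
# x is deleted and any later use of it raises an error.
print(x.is_deleted())
x = None  # drop the reference so it cannot be reused by accident
```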
If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.

For more details on buffer donation see the FAQ.

donate_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

keep_unused (bool) – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments are neither transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.
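The argnums/argnames resolution described for static_argnames (and, in the same way, for donate_argnames) can be checked with a minimal sketch: only static_argnames is given, yet because inspect.signature maps n to positional index 1, a positional call is also treated as static. The function repeat_square is hypothetical.

```python
import jax
import jax.numpy as jnp

def repeat_square(x, n):  # hypothetical example function
    for _ in range(n):    # requires n to be a concrete (static) Python int
        x = x ** 2
    return x

# Only static_argnames is given; JAX uses inspect.signature to find that
# "n" is also positional argument 1, so n is static however it is passed.
f = jax.jit(repeat_square, static_argnames="n")

x = jnp.arange(4)
by_keyword = f(x, n=2)
by_position = f(x, 2)  # also treated as static, so range(n) works
```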
device (xc.Device | None) – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

backend (str | None) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

inline (bool) – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.
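A minimal sketch of the inline flag, comparing the jaxprs of two hypothetical outer functions that call the same inner function jitted with and without inline=True. The name printed for the nested call equation varies across JAX versions (xla_call in older releases, pjit or jit in newer ones), so the sketch only compares equation counts rather than primitive names.

```python
import jax
import jax.numpy as jnp

def inner(x):  # hypothetical example function
    return jnp.sin(x) * 2.0

nested = jax.jit(inner)               # default inline=False
inlined = jax.jit(inner, inline=True)

def outer_nested(x):
    return nested(x) + 1.0

def outer_inlined(x):
    return inlined(x) + 1.0

x = jnp.ones(3)
nested_jaxpr = jax.make_jaxpr(outer_nested)(x)
inlined_jaxpr = jax.make_jaxpr(outer_inlined)(x)

# inline=False: inner appears as a single call equation holding a sub-jaxpr.
# inline=True: inner's sin and mul equations are spliced into the outer jaxpr.
print(nested_jaxpr)
print(inlined_jaxpr)
```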
- Return type:
- Returns:
A wrapped version of fun, set up for just-in-time compilation.
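A minimal sketch of using the returned wrapper, with a hypothetical function f: calling it compiles lazily on first use for the given shapes and dtypes, and the same wrapper also exposes the ahead-of-time path (lower, then compile) for producing an executable before running it.

```python
import jax
import jax.numpy as jnp

def f(x):  # hypothetical example function
    return jnp.tanh(x) * 2.0

f_jit = jax.jit(f)            # returns a wrapped callable; nothing is compiled yet
x = jnp.ones((3,))

y = f_jit(x)                  # first call traces and compiles for float32[3]
lowered = f_jit.lower(x)      # ahead-of-time: trace and lower without running
compiled = lowered.compile()  # compile explicitly
z = compiled(x)               # run the precompiled executable
print(lowered.as_text()[:200])  # start of the lowered module text
```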
Examples
In the following example, selu can be compiled into a single fused kernel by XLA:

>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ]
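The purity requirement on fun matters because the Python body runs only while JAX traces it. A minimal sketch with a hypothetical function impure_scale: a list append stands in for any side effect, and it fires once per new input shape/dtype, not once per call.

```python
import jax
import jax.numpy as jnp

log = []

@jax.jit
def impure_scale(x):  # hypothetical example function
    log.append("traced")  # Python side effect: happens at trace time only
    return 2.0 * x

impure_scale(jnp.ones(3))   # traces, compiles, runs
impure_scale(jnp.zeros(3))  # same shape/dtype: cached executable, body not re-run
impure_scale(jnp.ones(4))   # new shape: traced again
print(log)
```

For printing runtime values from inside a jitted function, jax.debug.print is the usual tool instead of a Python side effect.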
To pass arguments such as static_argnames when decorating a function, a common pattern is to use functools.partial():

>>> from functools import partial
>>> import jax.numpy as jnp
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
...   for i in range(n):
...     x = x ** 2
...   return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)
- Parameters:
abstracted_axes (Any | None) –