jax.jit
jax.jit(fun, *, static_argnums=None, static_argnames=None, device=None, backend=None, donate_argnums=(), inline=False, keep_unused=False, abstracted_axes=None)
Sets up fun for just-in-time compilation with XLA.

Parameters
- fun (Callable): Function to be jitted. fun should be a pure function, because Python side effects run only while fun is being traced (on the first call, and again whenever a new compilation is needed), not on every call. The arguments and return value of fun should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.

  JAX keeps a weak reference to fun for use as a compilation cache key, so the object fun must be weakly-referenceable. Most Callable objects already satisfy this requirement.
- static_argnums (Union[int, Iterable[int], None]): An optional int or collection of ints that specifies which positional arguments to treat as static (compile-time constants). Operations that depend only on static arguments are constant-folded in Python during tracing, so the corresponding argument values can be any Python object.
  Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.

  If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.
- static_argnames (Union[str, Iterable[str], None]): An optional string or collection of strings specifying which named arguments to treat as static (compile-time constants). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
- device (Optional[Device]): This is an experimental feature and the API is likely to change. Optionally, the device on which the jitted function will run. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA's DeviceAssignment logic and is usually jax.devices()[0].
- backend (Optional[str]): This is an experimental feature and the API is likely to change. Optionally, a string naming the XLA backend: 'cpu', 'gpu', or 'tpu'.
- donate_argnums (Union[int, Iterable[int]]): Specifies which positional argument buffers are "donated" to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example by recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated. Note that donate_argnums only works for positional arguments; keyword arguments will not be donated.

  For more details on buffer donation, see the FAQ.
- inline (bool): Specifies whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Defaults to False.
- keep_unused (bool): If False (the default), arguments that JAX determines to be unused by fun may be dropped from the resulting compiled XLA executables. Such arguments are neither transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.
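The rule that JAX maps static_argnames onto positional slots through inspect.signature(fun) can be checked directly. The following is a minimal sketch; the function scale and its parameters factor and mode are hypothetical and chosen only because a Python string used in an if statement must be static. Only static_argnames is given, yet passing mode positionally still works because JAX finds that "mode" sits at position 2.

```python
import jax
import jax.numpy as jnp

def scale(x, factor, mode):
    # Hypothetical function: `mode` is a Python str that drives control flow,
    # so it cannot be traced and must be marked static.
    if mode == "double":
        return x * factor * 2
    return x * factor

# Only static_argnames is given; JAX inspects the signature of `scale`
# and also treats the matching positional slot (index 2) as static.
scale_jit = jax.jit(scale, static_argnames="mode")

x = jnp.arange(3.0)
by_position = scale_jit(x, 1.5, "double")      # `mode` passed positionally
by_keyword = scale_jit(x, 1.5, mode="double")  # `mode` passed by name

print(by_position.tolist())  # [0.0, 3.0, 6.0]
```

Each distinct value of mode (for example "double" versus some other string) produces its own cache entry and compilation, as described for static_argnums.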
Returns
A wrapped version of fun, set up for just-in-time compilation.
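Because the wrapper traces fun and caches the compiled result, Python code inside fun runs only when a new trace is needed, which is why fun should be pure. A minimal sketch that makes this visible, assuming a hypothetical module-level list trace_log used purely as a probe (mutating outside state like this is exactly the impurity fun should normally avoid):

```python
import jax
import jax.numpy as jnp

# Hypothetical list used only to observe when Python code inside `f` runs.
trace_log = []

@jax.jit
def f(x):
    # Python side effect: executes while JAX traces `f`, not when the
    # cached XLA executable is re-run on later calls.
    trace_log.append(x.shape)
    return jnp.sin(x) * 2.0

f(jnp.ones(3))   # first call: trace and compile, logs (3,)
f(jnp.zeros(3))  # same shape and dtype: reuses the executable, logs nothing
f(jnp.ones(5))   # new shape: retrace and recompile, logs (5,)

print(trace_log)  # [(3,), (5,)]
```

To print values at run time rather than at trace time, jax.debug.print is the usual tool.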
Examples
In the following example, selu can be compiled into a single fused kernel by XLA:

>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...     return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ]
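Buffer donation (donate_argnums) suits functions whose output replaces one of their inputs, such as a state update inside a loop. A minimal sketch, assuming a hypothetical update function and hypothetical state and delta arrays; on backends that do not implement donation, JAX may warn that the donated buffer was not usable and simply ignore the request, and the result is the same either way.

```python
import jax
import jax.numpy as jnp

def update(state, delta):
    # Hypothetical step: the new state has the same shape and dtype as the
    # old one, so XLA may write the result into the donated input buffer.
    return state + delta

# Donate positional argument 0 (`state`); `delta` is not donated.
update_jit = jax.jit(update, donate_argnums=0)

state = jnp.zeros(4)
delta = jnp.ones(4)
for _ in range(3):
    # Rebind `state` to the output; the donated input must not be read again.
    state = update_jit(state, delta)

print(state.tolist())  # [3.0, 3.0, 3.0, 3.0]
```

Rebinding the name on every call is the key habit: once an array has been donated, reading it again raises an error on backends that honor donation.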
To pass arguments such as static_argnames when decorating a function, a common pattern is to use functools.partial():

>>> from functools import partial
>>> import jax.numpy as jnp
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
...     for i in range(n):
...         x = x ** 2
...     return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)
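Marking n as static is required here, not optional: without it, JAX traces n as an abstract value, and range() cannot turn that tracer into a Python int. A minimal sketch contrasting the two cases; the names failed, g_static, and result are illustrative only, and catching TypeError relies on JAX's tracer conversion errors subclassing TypeError.

```python
import jax
import jax.numpy as jnp

def g(x, n):
    for i in range(n):  # range() needs a concrete Python int
        x = x ** 2
    return x

# Without static marking, `n` is traced, and range(n) raises while tracing.
try:
    jax.jit(g)(jnp.arange(4), 3)
    failed = False
except TypeError:  # JAX's tracer-to-int conversion errors subclass TypeError
    failed = True

# Marking `n` static by position is equivalent to static_argnames=['n'].
g_static = jax.jit(g, static_argnums=1)
result = g_static(jnp.arange(4), 3)

print(result.tolist())  # [0, 1, 256, 6561]
```

Each new value of n unrolls the loop a different number of times, so it triggers a fresh trace and compilation.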