jax.stages package

Interfaces to stages of the compiled execution process.

JAX transformations that compile just in time for execution, such as jax.jit and jax.pmap, also support a common means of explicit lowering and compilation ahead of time. This module defines types that represent the stages of this process.

For more, see the AOT walkthrough.
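A minimal sketch of the full pipeline, assuming a recent JAX release on the default backend. The function f and its inputs are hypothetical examples. Each step produces one of the stage types documented below.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # A small example computation to stage out (hypothetical).
    return jnp.sin(x) * y + 1.0

x = jnp.arange(4.0)
y = jnp.ones(4)

jitted = jax.jit(f)           # Wrapped: callable, and also lowerable ahead of time
lowered = jitted.lower(x, y)  # Lowered: staged out, not yet compiled
compiled = lowered.compile()  # Compiled: ready to execute
out = compiled(x, y)          # runs the precompiled executable

# The explicit AOT path computes the same result as the ordinary JIT call.
print(jnp.allclose(out, jitted(x, y)))
```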

Classes

class jax.stages.Wrapped(*args, **kwargs)

A function ready to be specialized, lowered, and compiled.

This protocol reflects the output of functions such as jax.jit. Calling it results in JIT (just-in-time) lowering, compilation, and execution. It can also be explicitly lowered prior to compilation, and the result compiled prior to execution.

__call__(*args, **kwargs)

Executes the wrapped function, lowering and compiling as needed.

lower(*args, **kwargs)

Lower this function explicitly for the given arguments.

A lowered function is staged out of Python and translated to a compiler’s input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled.

Return type

Lowered

Returns

A Lowered instance representing the lowering.
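Because lowering only depends on argument shapes and dtypes, lower can also be given abstract placeholders instead of concrete arrays. A minimal sketch, assuming jax.ShapeDtypeStruct is available (it is in recent JAX releases) and the default 32-bit mode; matvec is a hypothetical example function:

```python
import jax
import jax.numpy as jnp

def matvec(m, v):
    return m @ v

# Only shapes and dtypes are needed to lower; no input data is allocated here.
m_spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
v_spec = jax.ShapeDtypeStruct((4,), jnp.float32)

lowered = jax.jit(matvec).lower(m_spec, v_spec)
compiled = lowered.compile()

# The compiled computation is then called on real arrays of matching shape/dtype.
m = jnp.ones((3, 4), jnp.float32)
v = jnp.arange(4, dtype=jnp.float32)
print(compiled(m, v))  # [6. 6. 6.]
```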

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)

Lowering of a function specialized to argument types and values.

A lowering is a computation ready for compilation. This class carries a lowering together with the remaining information needed to later compile and execute it. It also provides a common API for querying properties of lowered computations across JAX’s various lowering paths (jit, pmap, etc.).

Parameters
  • lowering (XlaLowering) –

  • out_tree (PyTreeDef) –

  • no_kwargs (bool) –

as_text(dialect=None)

A human-readable text representation of this lowering.

Intended for visualization and debugging purposes. This need not be a valid or reliable serialization. The text is relayed unchanged from the underlying lowering to the caller.

Parameters

dialect (Optional[str]) – Optional string specifying a lowering dialect (e.g. “mhlo”)

Return type

str
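A minimal sketch of inspecting a lowering as text, assuming a recent JAX release. The function f is a hypothetical example, and the exact text format (and which dialect strings are accepted) varies between versions, so the output is best treated as a debugging aid only.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) * 2.0

lowered = jax.jit(f).lower(jnp.zeros((2, 2)))

# With dialect=None the default dialect is used (StableHLO in recent releases);
# the text is meant for humans and may change between versions.
text = lowered.as_text()
print(text)
```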

compile()

Compile, returning a corresponding Compiled instance.

Return type

Compiled
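Compilation happens once; the resulting Compiled object can then be reused for any arguments with the same types, regardless of their values. A minimal sketch, assuming the default 32-bit mode (jax_enable_x64 off) so that the float arrays created below match the float32 specs; affine is a hypothetical example function:

```python
import jax
import jax.numpy as jnp

def affine(x, w, b):
    return x * w + b

spec = jax.ShapeDtypeStruct((3,), jnp.float32)
compiled = jax.jit(affine).lower(spec, spec, spec).compile()

# The same Compiled object serves any values with matching shapes and dtypes.
a = compiled(jnp.ones(3), jnp.full(3, 2.0), jnp.zeros(3))
b = compiled(jnp.arange(3.0), jnp.ones(3), jnp.ones(3))
print(a)  # [2. 2. 2.]
print(b)  # [1. 2. 3.]
```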

compiler_ir(dialect=None)

An arbitrary object representation of this lowering.

Intended for debugging purposes. This is neither a valid nor a reliable serialization, and the output is not guaranteed to be consistent across invocations.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Parameters

dialect (Optional[str]) – Optional string specifying a lowering dialect (e.g. “mhlo”)

Return type

Optional[Any]
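A minimal sketch of retrieving the compiler IR object, assuming a recent JAX release. The default dialect is used here because the set of accepted dialect strings has changed across versions; the concrete type of the returned object (for example an MLIR module) is likewise version-dependent, so the sketch only prints its type name and a truncated rendering.

```python
import jax
import jax.numpy as jnp

lowered = jax.jit(lambda x: x + 1).lower(jnp.zeros(3))

# Default dialect; the returned object's type depends on the dialect and the
# JAX version, and it may be None if unavailable.
ir = lowered.compiler_ir()
if ir is not None:
    print(type(ir).__name__)
    print(str(ir)[:300])  # truncated for display
```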

property in_tree: jaxlib.xla_extension.pytree.PyTreeDef

Tree structure of the pair (positional arguments, keyword arguments).

Return type

PyTreeDef
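To make the (positional arguments, keyword arguments) pairing concrete, the sketch below lowers a hypothetical function f with one positional and one keyword argument, then unflattens placeholder labels through in_tree. It assumes jax.jit accepts the keyword argument as an ordinary traced input and that each array or scalar argument is a single leaf.

```python
import jax
import jax.numpy as jnp

def f(x, *, scale):
    return x * scale

x = jnp.ones(3)
lowered = jax.jit(f).lower(x, scale=2.0)

# in_tree is the structure of ((positional args,), {keyword args}).
print(lowered.in_tree)

# Unflattening placeholder leaves makes the (args, kwargs) pairing visible.
labels = jax.tree_util.tree_unflatten(lowered.in_tree, ["x", "scale"])
print(labels)
```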

class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)

Compiled representation of a function specialized to argument types and values.

A compiled computation is associated with an executable and the remaining information needed to execute it. It also provides a common API for querying properties of compiled computations across JAX’s various compilation paths and backends.

__call__(*args, **kwargs)

Executes the compiled computation on the given arguments, which must match the argument types for which it was compiled.
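Unlike the Wrapped function, a Compiled object does not retrace or recompile when given new argument types. A minimal sketch, assuming a recent JAX release, where a shape mismatch is typically reported as a TypeError (the sketch also catches ValueError in case this differs by version):

```python
import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: x * 2).lower(jnp.zeros(3)).compile()
print(compiled(jnp.ones(3)))  # [2. 2. 2.]

# Arguments whose types differ from those used for lowering are rejected
# instead of triggering a new compilation.
rejected = False
try:
    compiled(jnp.ones(4))  # shape (4,) instead of (3,)
except (TypeError, ValueError) as err:
    rejected = True
    print("rejected:", type(err).__name__)
```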

as_text()

A human-readable text representation of this executable.

Intended for visualization and debugging purposes. This is neither a valid nor a reliable serialization.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type

Optional[str]
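A minimal sketch of printing the compiled program, assuming a recent JAX release on an XLA backend, where the text is typically the post-optimization HLO. Because the method may return None, the sketch checks before printing.

```python
import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: jnp.sum(x ** 2)).lower(jnp.ones(8)).compile()

# Post-optimization program text; None if the backend cannot provide it.
text = compiled.as_text()
if text is not None:
    print(text)
```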

cost_analysis()

A summary of execution cost estimates.

Intended for visualization and debugging purposes. The returned object is a simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). Beyond that, its structure is arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type

Optional[Any]
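Since the structure of the result is not fixed, code that reads it should be defensive. A minimal sketch, assuming that some JAX versions return a list holding one dict while newer ones return a dict directly, and that a "flops" key is commonly (but not necessarily) present:

```python
import jax
import jax.numpy as jnp

def f(a, b):
    return a @ b

a = jnp.ones((64, 64))
compiled = jax.jit(f).lower(a, a).compile()
costs = compiled.cost_analysis()

# Normalize the version-dependent shape of the result (assumption: list of
# dicts in some releases, a plain dict in others).
if isinstance(costs, (list, tuple)):
    costs = costs[0] if costs else None

if costs is not None:
    # Keys vary by backend; fall back to None if "flops" is missing.
    print(costs.get("flops"))
```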

property in_tree: jaxlib.xla_extension.pytree.PyTreeDef

Tree structure of the pair (positional arguments, keyword arguments).

Return type

PyTreeDef

memory_analysis()

A summary of estimated memory requirements.

Intended for visualization and debugging purposes. The returned object is a simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). Beyond that, its structure is arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type

Optional[Any]
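A minimal sketch of reading memory estimates, assuming a recent JAX release. The attribute names below are assumptions based on XLA's compiled-memory statistics and are read with getattr and a default, so a missing field yields None rather than an error.

```python
import jax
import jax.numpy as jnp

a = jnp.ones((256, 256))
compiled = jax.jit(lambda a: (a @ a).sum()).lower(a).compile()
mem = compiled.memory_analysis()

report = {}
if mem is not None:
    # Hypothetical field selection; getattr keeps the sketch working if a
    # field is absent in a given JAX/jaxlib version.
    for field in ("argument_size_in_bytes", "output_size_in_bytes", "temp_size_in_bytes"):
        report[field] = getattr(mem, field, None)
print(report)
```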

runtime_executable()

An arbitrary object representation of this executable.

Intended for debugging purposes. This is neither a valid nor a reliable serialization, and the output is not guaranteed to be consistent across invocations.

Returns None if unavailable, e.g. depending on the backend, compiler, or runtime.

Return type

Optional[Any]