jax.make_jaxpr

jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[Hashable, int]] | None = None, return_shape: Literal[False] = False, abstracted_axes: Any | None = None) -> Callable[..., ClosedJaxpr]
jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[Hashable, int]] | None = None, return_shape: Literal[True] = False, abstracted_axes: Any | None = None) -> Callable[..., tuple[ClosedJaxpr, Any]]

Creates a function that produces its jaxpr given example args.

Parameters:
  • fun – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.

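  • static_argnums – An int or collection of ints specifying which positional arguments to treat as static (trace-time constants). See the jax.jit() docstring.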

  • axis_env – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().

  • return_shape – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape, dtype, and named_shape attributes representing the corresponding types of the output leaves.
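As a sketch of static_argnums, the hypothetical function repeat_sin below uses its second argument as a Python loop bound. Marking that argument static keeps it a concrete int during tracing, so the loop is unrolled into the jaxpr and only x becomes a jaxpr input; without static_argnums=1, range(n) would receive a tracer and raise an error. Walking closed.jaxpr.eqns relies on ClosedJaxpr attributes that this page does not document, so treat that part as an assumption about recent JAX releases.

```python
import jax
import jax.numpy as jnp


def repeat_sin(x, n):
    # Hypothetical example: `n` drives a Python loop, so it must be a
    # concrete int at trace time rather than a tracer.
    for _ in range(n):
        x = jnp.sin(x)
    return x


# Mark argument 1 (`n`) as static: it is not traced, and the loop is unrolled.
closed = jax.make_jaxpr(repeat_sin, static_argnums=1)(1.0, 3)
print(closed)

# One `sin` equation per loop iteration; `n` itself is baked into the trace.
primitives = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitives)  # ['sin', 'sin', 'sin']
```

Calling the result again with a different value for n, e.g. 5, produces a different jaxpr, since each static value is traced separately.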

Returns:

A wrapped version of fun that, when applied to example arguments, returns a ClosedJaxpr representation of fun on those arguments. If the argument return_shape is True, then the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of fun.
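A minimal sketch of return_shape=True, using a hypothetical function stats that returns a dict. The second element of the result mirrors that dict, and its leaves carry shapes and dtypes rather than values. The dtype shown assumes the default 32-bit mode (jax_enable_x64 off).

```python
import jax
import jax.numpy as jnp


def stats(x):
    # Hypothetical example: returns a dict pytree, which `out_shape` mirrors.
    return {"col_sum": x.sum(axis=0), "scaled": 2.0 * x}


closed, out_shape = jax.make_jaxpr(stats, return_shape=True)(jnp.ones((2, 3)))
print(closed)

# Leaves of `out_shape` describe the outputs' types, not their values.
print(out_shape["col_sum"].shape, out_shape["col_sum"].dtype)  # (3,) float32
print(out_shape["scaled"].shape)  # (2, 3)
```

This is handy when only the output structure is needed, since no computation on real data takes place.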

A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.
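To illustrate the ShapedArray-level abstraction, the sketch below traces the same function with two different scalar values and then with a length-4 array, and walks the resulting ClosedJaxpr. It assumes that comparing printed jaxprs is a fair equality check, and that the in_avals, out_avals, and jaxpr.eqns attributes of ClosedJaxpr behave as in recent JAX releases; they are not documented on this page.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


# The trace records shapes and dtypes, not values: two Python floats
# abstract to the same f32[] type and give identical jaxprs.
same_a = str(jax.make_jaxpr(f)(3.0))
same_b = str(jax.make_jaxpr(f)(-1.5))
print(same_a == same_b)  # True

# A different input shape yields a different jaxpr (f32[4] instead of f32[]).
closed = jax.make_jaxpr(f)(jnp.zeros(4))

# The ClosedJaxpr can also be inspected programmatically.
print(closed.in_avals)  # abstract input types
print([eqn.primitive.name for eqn in closed.jaxpr.eqns])  # ['cos', 'sin']
print(closed.out_avals[0].shape)  # (4,)
```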

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0 d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }
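make_jaxpr composes with other transformations in the same way as with jax.grad. A short sketch with jax.vmap follows; the batch size of 5 is arbitrary, and the expected primitive list assumes that jax.numpy's elementwise functions are inlined into the trace, as in the outputs above.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


# vmap maps f over a leading axis of size 5; make_jaxpr traces the result.
batched = jax.make_jaxpr(jax.vmap(f))(jnp.arange(5.0))
print(batched)

# The elementwise primitives are unchanged; only the shapes gain the batch axis.
print([eqn.primitive.name for eqn in batched.jaxpr.eqns])  # ['cos', 'sin']
print(batched.out_avals[0].shape)  # (5,)
```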