jax.make_jaxpr


jax.make_jaxpr(fun, static_argnums=(), axis_env=None, return_shape=False, abstracted_axes=None)

Creates a function that produces its jaxpr given example args.

Parameters:
  • fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.

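  • static_argnums (int | Iterable[int]) – Indices of the positional arguments to treat as static (compile-time constant) values rather than tracing them. See the jax.jit() docstring for details.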

  • axis_env (Sequence[tuple[AxisName, int]] | None) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().

  • return_shape (bool) – Optional boolean, defaults to False. If True, the wrapped function returns a pair: the first element is the ClosedJaxpr representation of fun, and the second element is a pytree with the same structure as the output of fun whose leaves are objects with shape, dtype, and named_shape attributes describing the corresponding output leaves.
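As a rough illustration of the static_argnums and axis_env parameters described above, the following sketch traces two small functions. The helper names repeated_sin and subtract_axis_sum, the axis name "i", and the axis size 4 are made up for this example and are not part of the API.

```python
import jax
import jax.numpy as jnp


def repeated_sin(x, n):
    # `n` is used as a plain Python integer, so it has to be static.
    for _ in range(n):
        x = jnp.sin(x)
    return x


# Argument 1 (`n`) is static: the Python loop is unrolled while tracing,
# and only `x` appears as an input of the resulting jaxpr.
static_jaxpr = jax.make_jaxpr(repeated_sin, static_argnums=1)(1.0, 3)
print(static_jaxpr)


def subtract_axis_sum(x):
    # A collective over the named axis "i"; outside of jax.pmap() this
    # axis only exists because axis_env declares it below.
    return x - jax.lax.psum(x, "i")


# Declare a named axis "i" of size 4 so the collective can be traced.
collective_jaxpr = jax.make_jaxpr(subtract_axis_sum, axis_env=[("i", 4)])(2)
print(collective_jaxpr)
```

Without the axis_env entry, tracing subtract_axis_sum would be expected to fail because the axis name "i" is unbound.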

Return type:

Callable[…, core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]

Returns:

A wrapped version of fun that, when applied to example arguments, returns a ClosedJaxpr representation of fun on those arguments. If return_shape is True, the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shapes, dtypes, and named shapes of the output of fun.
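A minimal sketch of the return_shape=True behavior described above. The function summarize and its dictionary keys are hypothetical, chosen only to show that the second element of the pair mirrors the output structure.

```python
import jax
import jax.numpy as jnp


def summarize(x):
    # Returns a dict (a standard pytree) with a scalar and an array leaf.
    return {"total": jnp.sum(x), "doubled": 2 * x}


example_input = jnp.ones((3, 4))
closed_jaxpr, out_shapes = jax.make_jaxpr(summarize, return_shape=True)(example_input)

# `out_shapes` has the same dict structure as the output of `summarize`;
# each leaf exposes `shape` and `dtype` attributes.
print(out_shapes["total"].shape, out_shapes["total"].dtype)
print(out_shapes["doubled"].shape, out_shapes["doubled"].dtype)
```

No computation on real data is needed to obtain out_shapes; it comes from the same trace that produces the jaxpr.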

A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.
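One way to see what abstraction to the ShapedArray level means in practice is to compare traces for different inputs. In this sketch (the function name f_example is illustrative), arguments with the same shape and dtype but different values should give the same jaxpr, while a different shape gives a different one.

```python
import jax
import jax.numpy as jnp


def f_example(x):
    return jnp.sin(jnp.cos(x))


# Same shape and dtype, different values.
jaxpr_zeros = jax.make_jaxpr(f_example)(jnp.zeros((3,)))
jaxpr_ones = jax.make_jaxpr(f_example)(jnp.ones((3,)))

# Different shape.
jaxpr_matrix = jax.make_jaxpr(f_example)(jnp.zeros((2, 2)))

# The trace depends on shape and dtype, not on the concrete values.
print(str(jaxpr_zeros) == str(jaxpr_ones))
print(str(jaxpr_zeros) == str(jaxpr_matrix))
```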

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0 d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }
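Besides printing, the returned ClosedJaxpr can be inspected programmatically. The sketch below reuses the sin(cos(x)) function from the examples above (renamed sin_of_cos here to keep the block self-contained) and lists the primitives of its equations together with the abstract input types.

```python
import jax
import jax.numpy as jnp


def sin_of_cos(x):
    return jnp.sin(jnp.cos(x))


closed_jaxpr = jax.make_jaxpr(sin_of_cos)(3.0)

# Each equation in the jaxpr records the primitive it applies.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)

# Abstract values (shape and dtype) of the traced inputs.
input_avals = closed_jaxpr.in_avals
print([(aval.shape, aval.dtype) for aval in input_avals])
```

The primitive names correspond to the cos and sin equations shown in the printed jaxpr for f.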
Parameters:
  • abstracted_axes (Any | None)