JAX Internals: primitives
Introduction to JAX primitives
A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide).
For example, the multiply-add operation can be implemented in terms of the low-level jax.lax.* primitives (which are like XLA operator wrappers) or jax.core.Primitive("multiply_add"), as demonstrated further below.
JAX can take sequences of such primitive operations and transform them using its composable transformations of Python functions, such as jax.jit(), jax.grad() and jax.vmap(). JAX implements these transformations in a JAX-traceable way. This means that when a Python function is executed, the only operations it applies to the data are either:
Inspections of data attributes: Data information, such as shape or type; or
JAX primitives: These are the JAX special operations covered in this tutorial.
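To make the two categories concrete, here is a small sketch (the function names are hypothetical). The first function only reads a static attribute and applies JAX operations, so jax.jit can trace it. The second asks for the concrete value of its input, which is not available during tracing; JAX reports this with jax.errors.ConcretizationTypeError.

```python
import jax
import jax.numpy as jnp

def traceable(x):
  # Reading a static attribute (the shape) is allowed while tracing...
  n = x.shape[0]
  # ...and so is applying JAX operations, which are built from primitives.
  return jnp.sum(x) / n

def not_traceable(x):
  # float() needs the concrete data, which jit does not provide while tracing.
  return float(x[0]) + 1.0

x = jnp.arange(4.0)
mean = jax.jit(traceable)(x)
print(mean)  # 1.5

try:
  jax.jit(not_traceable)(x)
  failed = False
except jax.errors.ConcretizationTypeError:
  failed = True
print("not_traceable failed under jit:", failed)  # True
```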
JAX primitives know how to operate on both concrete data values and abstract JAX values. A JAX-traceable function can be invoked by JAX with abstract arguments. For example, the JAX abstract value ShapedArray(float32[2,2]) captures the type and the shape of values, but not the concrete data values.
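One way to observe an abstract value is to look at the aval attribute of the tracer that JAX passes in while tracing, or to call jax.eval_shape, which evaluates a function using only shapes and dtypes. This is a minimal sketch; the function name double and the seen_avals list are illustrative.

```python
import jax
import jax.numpy as jnp

seen_avals = []

def double(x):
  # While tracing, x is a tracer; its abstract value records only shape and dtype.
  seen_avals.append(x.aval)
  return x * 2.0

jax.jit(double)(jnp.ones((2, 2), jnp.float32))
print(seen_avals[0])  # the ShapedArray for a 2x2 float32 input

# eval_shape performs the same kind of abstract evaluation with no data at all.
out = jax.eval_shape(double, jax.ShapeDtypeStruct((2, 2), jnp.float32))
print(out.shape, out.dtype)  # (2, 2) float32
```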
JAX-transformed functions must themselves be JAX-traceable so that the transformations compose, for example jax.jit(jax.jacfwd(jax.grad(f))).
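As a quick check of composability, the sketch below nests the three transformations from the text around a hypothetical function f(x) = x**3, whose second derivative is 6x.

```python
import jax

def f(x):
  return x ** 3

# Each transformation returns a JAX-traceable function, so they nest freely.
d2f = jax.jit(jax.jacfwd(jax.grad(f)))
print(d2f(2.0))  # 12.0, since the second derivative of x**3 is 6x
```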
JAX provides pre-defined primitives corresponding to most XLA operations, including add, matmul, sin, cos, and indexing.
In addition, JAX offers an implementation of NumPy functions in terms of JAX primitives. This means that Python programs using JAX’s implementation of NumPy are JAX-traceable and, therefore, transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
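For example, a multiply-add written with jax.numpy can be traced with jax.make_jaxpr, which prints the primitives the program reduces to. This is an illustrative sketch; depending on the JAX version, the jnp calls may appear inlined or nested inside a pjit equation, but the underlying mul and add primitives should be present in either case.

```python
import jax
import jax.numpy as jnp

def multiply_add_numpy(x, y, z):
  return jnp.add(jnp.multiply(x, y), z)

x = jnp.ones((3,), jnp.float32)
closed_jaxpr = jax.make_jaxpr(multiply_add_numpy)(x, x, x)
# The printed jaxpr lists the lax primitives (`mul`, `add`) behind the jnp calls.
print(closed_jaxpr)
```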
Furthermore, the set of JAX primitives is extensible, so instead of reimplementing a function in terms of pre-defined JAX primitives, you can define a new primitive that encapsulates the behavior of the function.
Consider the following example: you want to add JAX support for a multiply-add function with three arguments, defined mathematically as multiply_add(x, y, z) = x * y + z. This function operates on three identically-shaped tensors of floating-point values and performs the operations pointwise. You can do this by:
Using existing JAX primitives
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, for example, those defined in the jax.lax module:
```python
from jax import lax
# `api` is JAX's internal API module; jax.grad, jax.jit, etc. are the public equivalents.
from jax._src import api

def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the `jax.lax` primitives."""
  return lax.add(lax.mul(x, y), z)


def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
```

```
square_add_lax = 14.0
grad(square_add_lax) = 4.0
```
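Because multiply_add_lax is built only from jax.lax primitives, the other transformations apply to it as well. The sketch below (a standalone copy of the two functions above) batches it with jax.vmap and compiles its gradient with jax.jit.

```python
import jax
import jax.numpy as jnp
from jax import lax

def multiply_add_lax(x, y, z):
  return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
  return multiply_add_lax(a, a, b)

# Batch over `a` while passing the scalar `b` unbatched to every call.
batched = jax.vmap(square_add_lax, in_axes=(0, None))(jnp.array([1.0, 2.0, 3.0]), 10.0)
print(batched)  # [11. 14. 19.]

# Transformations compose: a jit-compiled gradient with respect to `a`.
slope = jax.jit(jax.grad(square_add_lax))(3.0, 10.0)
print(slope)  # 6.0
```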
To understand how JAX uses the primitives internally, add some helper functions for tracing function calls:
```python
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
  """Print a message at current indentation."""
  if msg is not None:
    print(" " * _indentation + msg)

def _trace_indent(msg=None):
  """Print a message and then indent the rest."""
  global _indentation
  _trace(msg)
  _indentation = 1 + _indentation

def _trace_unindent(msg=None):
  """Unindent then print a message."""
  global _indentation
  _indentation = _indentation - 1
  _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
      """Print certain values more succinctly"""
      vtype = str(type(v))
      if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
        return "<JaxComputationBuilder>"
      elif "jaxlib.xla_extension.XlaOp" in vtype:
        return "<XlaOp at 0x{:x}>".format(id(v))
      elif ("partial_eval.JaxprTracer" in vtype or
            "batching.BatchTracer" in vtype or
            "ad.JVPTracer" in vtype):
        return "Traced<{}>".format(v.aval)
      elif isinstance(v, tuple):
        return "({})".format(pp_values(v))
      else:
        return str(v)
    def pp_values(args):
      return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False
```
Instead of using jax.lax primitives directly, you can use other functions that are already written in terms of those primitives, such as those in jax.numpy:
```python
import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
  return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
  return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
```

```
Normal evaluation:
call square_add_numpy(2.0, 10.0)
call multiply_add_numpy(2.0, 2.0, 10.0)
|<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy = 14.0
Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) = 4.0
```
Notice that in the process of computing jax.grad(), JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this document). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.
The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
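The converse also holds in practice: calling ordinary NumPy on a traced value fails, because NumPy needs concrete data. A minimal sketch with two hypothetical functions that differ only in which sin they call; JAX reports the failure as jax.errors.TracerArrayConversionError.

```python
import jax
import jax.numpy as jnp
import numpy as np

def with_jnp(x):
  # Built from JAX primitives, so it is JAX-traceable.
  return jnp.sin(x) * x

def with_np(x):
  # Ordinary NumPy tries to convert the tracer into a concrete array.
  return np.sin(x) * x

g = jax.grad(with_jnp)(1.0)
print(g)  # cos(1) + sin(1), about 1.3818

try:
  jax.grad(with_np)(1.0)
  np_failed = False
except jax.errors.TracerArrayConversionError:
  np_failed = True
print("NumPy version failed under grad:", np_failed)  # True
```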
Defining new JAX primitives
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality.
```python
from jax import core

# Recent JAX releases also export this class as jax.extend.core.Primitive.
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.

  Note that the traced arguments must be passed as positional arguments
  to `bind`.
  """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)
```
If you try to call the newly defined functions, you’ll get an error, because you haven’t yet told JAX anything about the semantics of the new primitive.
```python
with expectNotImplementedError():
  square_add_prim(2., 10.)
```

```
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1156/2844449444.py", line 2, in <module>
square_add_prim(2., 10.)
File "/tmp/ipykernel_1156/1393342955.py", line 48, in func_wrapper
res = func(*args)
File "/tmp/ipykernel_1156/1751132419.py", line 17, in square_add_prim
return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
```
Primal evaluation rules

```python
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.

  Args:
    x, y, z: The concrete arguments of the primitive. Will only be called with
      concrete values.

  Returns:
    the concrete result of the primitive.
  """
  # Note: you can use the ordinary (non-JAX) NumPy, which is not JAX-traceable.
  return np.add(np.multiply(x, y), z)

# Now, register the primal implementation with JAX:
multiply_add_p.def_impl(multiply_add_impl)
```

```
<function __main__.multiply_add_impl(x, y, z)>
```
```python
assert square_add_prim(2., 10.) == 14.
```

```
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0
```
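For reference, the same steps can be condensed into a standalone sketch that defines a second, hypothetical primitive (fma_demo_p) with only a primal implementation and evaluates it eagerly on NumPy arrays. It assumes Primitive is importable from jax.extend.core in newer JAX releases and from jax.core in older ones. Without an abstract evaluation rule, only this eager path works; jax.jit would still fail, as in the next section.

```python
import numpy as np

try:  # newer JAX releases
  from jax.extend.core import Primitive
except ImportError:  # older JAX releases
  from jax.core import Primitive

# Hypothetical primitive, independent of multiply_add_p above.
fma_demo_p = Primitive("fma_demo")

def fma_demo_impl(x, y, z):
  # Called only with concrete values, so ordinary NumPy is fine here.
  return np.add(np.multiply(x, y), z)

fma_demo_p.def_impl(fma_demo_impl)

def fma_demo(x, y, z):
  # Traced arguments go to `bind` positionally.
  return fma_demo_p.bind(x, y, z)

result = fma_demo(np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([0.5, 0.5]))
print(result)  # [3.5 8.5]
```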
What happens when you use jit
Now, if you try to use jit, you’ll get a NotImplementedError:
```python
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
```

```
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1156/1813425700.py", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 356, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
```
Abstract evaluation rules
To JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:
It gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
It computes the shape and type of all arrays and operations used in the computation.
For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]) or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
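You can see both results of abstract evaluation for a function built from existing primitives with jax.make_jaxpr: the jaxpr lists the primitive sequence, and its out_avals give the result's shape and dtype, all computed without running the computation. A sketch using a standalone copy of square_add_lax on a 3-element float32 vector:

```python
import jax
import jax.numpy as jnp
from jax import lax

def square_add_lax(a, b):
  return lax.add(lax.mul(a, a), b)

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
closed = jax.make_jaxpr(square_add_lax)(x, x)
print(closed)  # the primitive sequence, annotated with f32[3] types
print([eqn.primitive.name for eqn in closed.jaxpr.eqns])  # ['mul', 'add']
print(closed.out_avals)  # abstract value(s) of the result
```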
```python
from jax import core

@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments.

  Args:
    xs, ys, zs: Abstractions of the arguments.

  Returns:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now, register the abstract evaluation with JAX:
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
```

```
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
```
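An abstract evaluation rule on its own is already enough for shape inference: jax.eval_shape traces a function using only abstract values and never needs a lowering rule. The sketch below assumes a separate, hypothetical primitive named fma_shape_demo; unlike the asserts above, its rule raises a TypeError on mismatched shapes, which is a design choice rather than a JAX requirement.

```python
import jax
import jax.numpy as jnp
from jax.core import ShapedArray

try:  # newer JAX releases
  from jax.extend.core import Primitive
except ImportError:  # older JAX releases
  from jax.core import Primitive

fma_shape_p = Primitive("fma_shape_demo")  # hypothetical primitive for this sketch

@fma_shape_p.def_abstract_eval
def _fma_shape_abstract_eval(xs, ys, zs):
  # Only shapes and dtypes are available here, never data.
  if not (xs.shape == ys.shape == zs.shape):
    raise TypeError(f"shape mismatch: {xs.shape}, {ys.shape}, {zs.shape}")
  return ShapedArray(xs.shape, xs.dtype)

spec = jax.ShapeDtypeStruct((3,), jnp.float32)
out = jax.eval_shape(lambda x, y, z: fma_shape_p.bind(x, y, z), spec, spec, spec)
print(out.shape, out.dtype)  # (3,) float32
```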
If you re-attempt to apply jit, you can inspect how the abstract evaluation proceeds, but you’ll get another error, this time about the missing XLA compilation (MLIR lowering) rule:
```python
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
```

```
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Found expected exception:
Traceback (most recent call last):
File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1156/1813425700.py", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 356, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
```
XLA Compilation rules
JAX compilation works by compiling each primitive into a graph of XLA operations, which the lowering rule below expresses as MLIR operations in the StableHLO dialect.
This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.
```python
from jax._src.lib.mlir.dialects import hlo

@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now, register the lowering rule with JAX.
# For GPU, refer to https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html
from jax.interpreters import mlir

mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
```

```
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>
```
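Writing the lowering with hlo operations by hand, as above, gives full control. When the primitive's semantics can be expressed with existing JAX operations, another option is mlir.lower_fun, which traces a JAX-traceable Python function and lowers it in place of the primitive. The standalone sketch below uses a hypothetical primitive fma_jit_demo and, by omitting the platform argument, registers the rule for all platforms.

```python
import jax
import numpy as np
from jax.core import ShapedArray
from jax.interpreters import mlir

try:  # newer JAX releases
  from jax.extend.core import Primitive
except ImportError:  # older JAX releases
  from jax.core import Primitive

fma_jit_p = Primitive("fma_jit_demo")  # hypothetical primitive for this sketch
fma_jit_p.def_impl(lambda x, y, z: np.add(np.multiply(x, y), z))
fma_jit_p.def_abstract_eval(lambda xs, ys, zs: ShapedArray(xs.shape, xs.dtype))

def _fma_reference(x, y, z):
  # JAX-traceable definition of the same math; lower_fun lowers it to StableHLO.
  return x * y + z

# No platform argument: the rule applies to every backend.
mlir.register_lowering(fma_jit_p, mlir.lower_fun(_fma_reference, multiple_results=False))

@jax.jit
def square_add_demo(a, b):
  return fma_jit_p.bind(a, a, b)

print(square_add_demo(2.0, 10.0))  # 14.0
```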
Applying jax.jit now succeeds. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point, JAX invokes multiply_add_lowering.
```python
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
```

```
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(..., platforms=('cpu',), ...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7eff60554370>]
```
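To inspect what the compiled computation looks like, you can lower a jitted function without running it and print the resulting StableHLO module. The sketch uses a standalone lax version of square_add instead of the custom primitive, so it runs on its own; it should contain the same multiply and add operations that multiply_add_lowering emits.

```python
import jax
from jax import lax

def square_add_lax(a, b):
  return lax.add(lax.mul(a, a), b)

# Lowering stops before compilation, so nothing is executed here.
lowered = jax.jit(square_add_lax).lower(2.0, 10.0)
hlo_text = lowered.as_text()
print(hlo_text)
```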
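Below is another use of jit, where you compile only with respect to the first argument. Notice how the second argument to square_add_prim is concrete (10.0). In the trace below, this value reaches multiply_add_lowering as a constant (stablehlo.constant) instead of a block argument, while multiply_add_abstract_eval still sees it as a ShapedArray. In general, multiply_add_abstract_eval may be called with both ShapedArray and ConcreteArray arguments, so it should rely only on their shape and dtype.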
```python
assert api.jit(lambda x, y: square_add_prim(x, y),
               static_argnums=1)(2., 10.) == 14.
```

```
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(..., platforms=('cpu',), ...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7eff64077230>]
```
Forward differentiation
JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in Advanced automatic differentiation).
If you attempt to call jvp on square_add_prim, you’ll get an error because you have not yet told JAX how to differentiate the multiply_add primitive.
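For comparison, a function built only from existing primitives can be differentiated without extra work, because every primitive it uses already has a registered JVP rule. The sketch below uses a hypothetical square_add_reference (not part of this tutorial) written with lax.mul and lax.add, to show the values that the custom primitive should eventually reproduce.

```python
import jax
from jax import lax

def square_add_reference(a, b):
  # Hypothetical stand-in for square_add_prim: a * a + b, built from
  # built-in primitives that already have JVP rules.
  return lax.add(lax.mul(a, a), b)

# Primal point (2., 10.), tangents (1., 1.).
# Output tangent = 2 * a * a_dot + b_dot = 2 * 2 * 1 + 1 = 5.
primal_out, tangent_out = jax.jvp(square_add_reference, (2., 10.), (1., 1.))
print(primal_out, tangent_out)
```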
# The second argument `(2., 10.)` contains the primal values at which
# the Jacobian is evaluated, and the third argument `(1., 1.)`
# contains the tangents for those arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1156/459539105.py", line 5, in <module>
api.jvp(square_add_prim, (2., 10.), (1., 1.))
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1687, in jvp
return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1716, in _jvp
out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented
from jax import lax
from jax.interpreters import ad

@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents),
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values
  for the arguments and tangents.

  Args:
    arg_values: A tuple of arguments
    arg_tangents: A tuple with the tangents of the arguments. The tuple has
      the same length as the arg_values. Some of the tangents may also be
      instances of the special value `ad.Zero` to specify a zero tangent

  Returns:
     A pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # The primal output can be computed with the `multiply_add_prim`
  # primitive itself, which is JAX-traceable.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # The tangent must also be computed in a JAX-traceable way. The output
  # tangent is (xt * y + x * yt + zt), which can be written with two nested
  # calls to `multiply_add_prim`: xt * y + (x * yt + zt).
  # `Zero` tangents need special handling. Here, you just turn them into a
  # proper tensor of 0s (of the same shape as `x`).
  # An alternative would be to check for `Zero` and perform algebraic
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan

  output_tangent = multiply_add_prim(
      make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX:
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, 1.0, 1.0)
call multiply_add_impl(2.0, 1.0, 1.0)
|<- multiply_add_impl = 3.0
|<- multiply_add_prim = 3.0
call multiply_add_prim(1.0, 2.0, 3.0)
call multiply_add_impl(1.0, 2.0, 3.0)
|<- multiply_add_impl = 5.0
|<- multiply_add_prim = 5.0
|<- multiply_add_value_and_jvp = (14.0, 5.0)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
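A quick way to gain confidence in a hand-written JVP rule is to compare it with a finite-difference estimate along the same tangent direction. The sketch below is plain Python with hypothetical helper names; it checks the formula xt * y + x * yt + zt used in multiply_add_value_and_jvp at the point traced above, (2., 2., 10.) with tangents (1., 1., 1.).

```python
def multiply_add(x, y, z):
  # Plain-Python stand-in for the multiply_add primitive: x * y + z.
  return x * y + z

def multiply_add_tangent(x, y, xt, yt, zt):
  # Same tangent formula as in multiply_add_value_and_jvp.
  return xt * y + x * yt + zt

def finite_difference_jvp(f, primals, tangents, eps=1e-6):
  # Central difference along the tangent direction approximates the JVP.
  plus = [p + eps * t for p, t in zip(primals, tangents)]
  minus = [p - eps * t for p, t in zip(primals, tangents)]
  return (f(*plus) - f(*minus)) / (2 * eps)

x, y, z = 2.0, 2.0, 10.0
xt, yt, zt = 1.0, 1.0, 1.0
analytic = multiply_add_tangent(x, y, xt, yt, zt)
numeric = finite_difference_jvp(multiply_add, (x, y, z), (xt, yt, zt))
print(analytic, numeric)
```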
JIT of forward differentiation
You can apply jit to the forward differentiation function:
assert api.jit(lambda arg_values, arg_tangents:
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7eff61757a30>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7eff605833b0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7eff605d3ef0>]
Notice that JAX first evaluates multiply_add_value_and_jvp abstractly, which in turn abstractly evaluates both the primal and the tangent computations (a total of 3 invocations of the multiply_add primitive). Only then does it lower and compile the 3 occurrences of the primitive.
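The abstract-evaluation step can be inspected on its own with jax.make_jaxpr, which traces a function with abstract values and records the primitives it would hand to the compiler. The sketch below applies it to the JVP of a hypothetical square_add_reference built from lax.mul and lax.add (a stand-in, since it does not depend on the custom primitive); the exact names of the tangent-side primitives may vary between JAX versions.

```python
import jax
from jax import lax

def square_add_reference(a, b):
  # Hypothetical stand-in for square_add_prim built from lax.mul and lax.add.
  return lax.add(lax.mul(a, a), b)

def jvp_of_square_add(arg_values, arg_tangents):
  return jax.jvp(square_add_reference, arg_values, arg_tangents)

# make_jaxpr traces with abstract values only; the resulting jaxpr lists the
# primal and tangent primitives that jit would then lower and compile.
closed_jaxpr = jax.make_jaxpr(jvp_of_square_add)((2., 10.), (1., 1.))
print(closed_jaxpr)

primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```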
Reverse differentiation
If you now attempt to use reverse differentiation, you’ll notice that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.
When computing reverse differentiation, JAX first performs an abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent.
Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents.
Notice that JAX uses the special abstract tangent value Zero for the tangent corresponding to the third argument of multiply_add_prim. This reflects the fact that you do not differentiate w.r.t. the second argument to square_add_prim, which flows to the third argument of multiply_add_prim.
Notice also that during the abstract evaluation of the tangent, the value 0.0 is passed as the tangent for the third argument. This is because of the use of the make_zero function in the definition of multiply_add_value_and_jvp.
# This is reverse differentiation w.r.t. the first argument of `square_add_prim`
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
Found expected exception:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 271, in get_primitive_transpose
return primitive_transposes[p]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1156/2155094905.py", line 3, in <module>
api.grad(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 392, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The above error occurs because JAX is missing one piece it needs to turn the forward differentiation code into reverse differentiation: a transposition rule for the multiply_add primitive.
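For a function made only of built-in primitives, both steps that reverse differentiation relies on (linearizing with concrete primal values, then transposing the linear tangent computation) can be performed explicitly with jax.linearize and jax.linear_transpose. The sketch below uses a hypothetical square_add_reference built from lax.mul and lax.add as a stand-in for square_add_prim; with the custom primitive, the transposition step is exactly where the error above is raised.

```python
import jax
from jax import lax

def square_add_reference(a, b):
  # Hypothetical stand-in for square_add_prim; lax.mul and lax.add already
  # have both JVP and transpose rules.
  return lax.add(lax.mul(a, a), b)

# Step 1 (linearize): evaluate the JVP at the concrete point (2., 10.) and keep
# the tangent computation as a linear function of the tangents.
primal_out, f_jvp = jax.linearize(square_add_reference, 2., 10.)

# Step 2 (transpose): map an output cotangent back to one cotangent per input.
# This step needs a transpose rule for every primitive in the tangent program.
f_vjp = jax.linear_transpose(f_jvp, 1., 1.)
a_ct, b_ct = f_vjp(1.)

# The result agrees with reverse-mode differentiation via jax.grad.
grads = jax.grad(square_add_reference, argnums=(0, 1))(2., 10.)
print(primal_out, a_ct, b_ct, grads)
```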
Transposition
As previously explained, when computing reverse differentiation, JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, JAX interprets this trace abstractly backwards, and for each primitive it applies a transposition rule.
To understand what is going on, consider a simpler example: the function f(x, y) = x * y + y. Assume you need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the tangents of the inputs xt and yt:
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
By construction, the tangent calculation is always linear in the input tangents. The only operator in it that could be non-linear is multiplication, but whenever multiplication appears, one of its operands is a constant, so the result is still linear in the tangents.
JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:
# Initialize cotangents of inputs and intermediate variables:
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output:
fct = 1.
# Process `ft = c + yt`:
cct += fct
yct += fct
# Process `c = a + b`:
act += cct
bct += cct
# Process `b = 2. * yt`:
yct += 2. * bct
# Process `a = xt * 4.`:
xct += act * 4.
One can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f.
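Since the backward pass above is plain Python, it can be checked mechanically. The sketch below (the names `f_tangent` and `f_backward` are illustrative, not from the tutorial) runs the hand-written accumulation and compares it with `jax.grad` on `f` and with `jax.linear_transpose`, which asks JAX to transpose the linear tangent map using its own built-in transposition rules.

```python
import jax

def f(x, y):
  return x * y + y

def f_tangent(xt, yt):
  # Hand-written JVP tangent map of f at the point (x, y) = (2., 4.).
  a = xt * 4.
  b = 2. * yt
  c = a + b
  return c + yt

def f_backward(fct):
  # Hand-written backward pass over f_tangent, as in the text.
  xct = yct = act = bct = cct = 0.
  cct += fct         # process `ft = c + yt`
  yct += fct
  act += cct         # process `c = a + b`
  bct += cct
  yct += 2. * bct    # process `b = 2. * yt`
  xct += act * 4.    # process `a = xt * 4.`
  return xct, yct

manual = f_backward(1.)
# Reverse-mode autodiff of f at the same point.
auto = jax.grad(f, argnums=(0, 1))(2., 4.)
# JAX transposes the linear tangent map; the example primals (1., 1.)
# are only used for their shape and dtype.
transposed = jax.linear_transpose(f_tangent, 1., 1.)(1.)

print(manual)                          # (4.0, 3.0)
print([float(v) for v in auto])        # [4.0, 3.0]
print([float(v) for v in transposed])  # [4.0, 3.0]
```

All three routes agree because the tangent map here is ft = 4*xt + 3*yt, whose transpose sends a cotangent of 1. to (4., 3.).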
JAX knows how to transpose each primitive that may appear in a JVP calculation. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined _ value, and for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None returned for the constant arguments.
In particular:
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
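The add and multiply rules above can be observed through `jax.linear_transpose`, which applies JAX's built-in transposition rules to a function that is linear in its inputs. In this sketch the constant `x = 3.` is an arbitrary choice and is closed over, so only `y` is a linear input, mirroring `mult_transpose(out_ct, x, _)`; the helper names are illustrative.

```python
import jax

x = 3.  # constant (non-linear) operand, closed over rather than passed in

def mul_by_x(y):
  return x * y  # linear in y

def add(a, b):
  return a + b  # linear in both a and b

# The example primals passed to linear_transpose only fix shapes and dtypes.
(y_ct,) = jax.linear_transpose(mul_by_x, 1.)(2.)    # expect x * out_ct
a_ct, b_ct = jax.linear_transpose(add, 1., 1.)(5.)  # expect (out_ct, out_ct)

print(float(y_ct))               # 6.0
print(float(a_ct), float(b_ct))  # 5.0 5.0
```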
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.
  This method is only used when computing the backward gradient following
  `value_and_jvp`, and is only needed for primitives that are used in the JVP
  calculation for some other primitive. You need a transposition for `multiply_add_prim`,
  because you have used `multiply_add_prim` in the computation of the `output_tangent` in
  `multiply_add_value_and_jvp`.
  In this case, multiply_add is not a linear primitive. However, it is used linearly
  w.r.t. tangents in `multiply_add_value_and_jvp`:
    `output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))`.
  One of the first two (multiplicative) arguments is always a constant.
  Args:
    ct: The cotangent of the output of the primitive.
    x, y, z: The values of the arguments. The arguments that are used linearly
      get an ad.UndefinedPrimal value. The other arguments get a constant
      value.
  Returns:
    A tuple with the cotangents of the inputs, with the value None
    corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x".
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y".
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res
ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
Now you can complete the run of the grad:
assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(1.0, 2.0, 0.0)
call multiply_add_impl(1.0, 2.0, 0.0)
|<- multiply_add_impl = 2.0
|<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
call multiply_add_prim(2.0, 1.0, 0.0)
call multiply_add_impl(2.0, 1.0, 0.0)
|<- multiply_add_impl = 2.0
|<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)
Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. Because transposition processes the tangent computation backwards, the first call to transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...), where y is the constant 2.0.
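A common way to sanity-check a transposition rule is the dot-product identity <ct, L(t)> = <L^T(ct), t>, which holds for any linear map L and its transpose L^T. The sketch below applies it to a plain `jax.numpy` stand-in for `multiply_add_prim` (the function `multiply_add` here is hypothetical, not the primitive defined earlier), using `jax.linearize` to obtain the linear tangent map and `jax.linear_transpose` to obtain its transpose. Applied to a function built from a custom primitive, the same check should exercise that primitive's registered JVP and transposition rules.

```python
import jax
import jax.numpy as jnp

def multiply_add(x, y, z):
  # Hypothetical stand-in for multiply_add_prim, written with jax.numpy.
  return x * y + z

primals = (jnp.array([2., 3.]), jnp.array([5., 7.]), jnp.array([10., 20.]))

# L: the tangent map at `primals`, linear in (xt, yt, zt).
_, L = jax.linearize(multiply_add, *primals)
# L_T: its transpose, mapping an output cotangent to input cotangents.
L_T = jax.linear_transpose(L, *primals)

# Random tangents and a random output cotangent.
k_x, k_y, k_z, k_ct = jax.random.split(jax.random.PRNGKey(0), 4)
tangents = tuple(jax.random.normal(k, p.shape)
                 for k, p in zip((k_x, k_y, k_z), primals))
ct = jax.random.normal(k_ct, (2,))

lhs = jnp.vdot(ct, L(*tangents))
rhs = sum(jnp.vdot(c, t) for c, t in zip(L_T(ct), tangents))
print(jnp.allclose(lhs, rhs, rtol=1e-5, atol=1e-5))  # True
```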
JIT of reverse differentiation#
Notice that the abstract evaluation of multiply_add_value_and_jvp now uses only abstract values (ShapedArray), whereas in the absence of JIT the trace showed ConcreteArray values.
assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7eff603fe670>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%4 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7eff617a1330>]
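Under JIT, the whole gradient computation, including the transposed tangent computation, is traced with abstract values only. A quick way to look at what gets traced is `jax.make_jaxpr`. The sketch below uses the simple f(x, y) = x * y + y from the transposition discussion rather than `square_add_prim`, so it does not depend on the custom primitive.

```python
import jax

def f(x, y):
  return x * y + y

grad_f = jax.jit(jax.grad(f))  # gradient with respect to x
print(grad_f(2., 4.))  # 4.0

# make_jaxpr traces with abstract (ShapedArray) inputs, as jit does.
jaxpr = jax.make_jaxpr(jax.grad(f))(2., 4.)
print(jaxpr)
```

The printed jaxpr contains only the primitives that produce the gradient; no concrete values of x or y are baked into it.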
Batching#
The batching transformation takes a point-wise computation and turns it into a computation on vectors. If you try it right now, you will get a NotImplementedError:
# The arguments are two vectors instead of two scalars.
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1156/1080163607.py", line 3, in <module>
api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 992, in vmap_f
out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented
You need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise on inputs of any shape, so the batched version can reuse the same multiply_add_prim implementation.
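Reusing the primitive is valid because an elementwise operation gives the same result whether it is mapped over a leading axis or applied to the whole vectors at once. A small check of that property, using a plain `jax.numpy` stand-in (the function `multiply_add` below is illustrative, not the primitive itself):

```python
import jax
import jax.numpy as jnp

def multiply_add(x, y, z):
  # Elementwise stand-in for multiply_add_prim.
  return x * y + z

xs = jnp.array([2., 3.])
zs = jnp.array([10., 20.])

# Map over axis 0, one scalar at a time...
mapped = jax.vmap(multiply_add)(xs, xs, zs)
# ...versus applying the elementwise function to the whole vectors.
direct = multiply_add(xs, xs, zs)
print(mapped, direct)  # [14. 29.] [14. 29.]
```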
from jax.interpreters import batching
@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.
  This must be a JAX-traceable function.
  Since the `multiply_add` primitive already operates point-wise on arbitrary
  dimension tensors, to batch it you can use the primitive itself. This works as
  long as all the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the same axis as the inputs.
  Args:
    vector_arg_values: A tuple of three arguments, each being a tensor of matching
      shape.
    batch_axes: The axes that are being batched. See vmap documentation.
  Returns:
    A tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]
batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
np.array([2., 3.]),
np.array([10., 20.])),
[14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
|<- multiply_add_impl = [14. 29.]
|<- multiply_add_prim = [14. 29.]
|<- multiply_add_batch = ([14. 29.], 0)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
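The multiply_add_batch rule above only handles the case where every operand is batched along the same axis. As a sketch of how a rule could relax that assumption (an illustration, not part of the original tutorial), the hypothetical helper below moves each batch axis to the front and broadcasts operands that are not batched, which `vmap` signals with a batch axis of None. It is written against `jax.numpy` so it runs on its own, and it is checked against `jax.vmap` directly.

```python
import jax
import jax.numpy as jnp

def general_multiply_add_batch(args, batch_axes):
  """Hypothetical batching rule for an elementwise x * y + z.

  `batch_axes[i]` is the batched axis of `args[i]`, or None if `args[i]` is
  not batched. Assumes all per-example operands share the same shape.
  Returns the batched result and its batch axis.
  """
  # The batch size is taken from any operand that is actually batched.
  size = next(a.shape[d] for a, d in zip(args, batch_axes) if d is not None)
  aligned = []
  for a, d in zip(args, batch_axes):
    if d is None:
      # Unbatched operand: replicate it along a new leading batch axis.
      a = jnp.broadcast_to(a, (size,) + jnp.shape(a))
    else:
      # Batched operand: move its batch axis to position 0.
      a = jnp.moveaxis(a, d, 0)
    aligned.append(a)
  x, y, z = aligned
  return x * y + z, 0

x = jnp.arange(6.).reshape(2, 3)       # batched along axis 0
y = jnp.arange(6.).reshape(3, 2) + 1.  # batched along axis 1
z = jnp.array([10., 20., 30.])         # not batched

out, out_axis = general_multiply_add_batch((x, y, z), (0, 1, None))
expected = jax.vmap(lambda a, b, c: a * b + c, in_axes=(0, 1, None))(x, y, z)
print(out_axis, jnp.allclose(out, expected))  # 0 True
```

Always returning batch axis 0 keeps the rule simple; a rule could instead return one of the input batch axes to avoid a transpose.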
JIT of batching#
Below is an example of applying JIT to batching:
assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
(np.array([2., 3.]),
np.array([10., 20.])),
[14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
|<- multiply_add_abstract_eval = ShapedArray(float32[2])
|<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, 0)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], ...), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7eff8017ab30>]
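In the JIT-of-batching trace above, the batching rule receives ShapedArray(float32[2]) operands. `jax.make_jaxpr` shows the same vector shapes for a batched elementwise function. The stand-in `square_add` below is a plain `jax.numpy` function (an assumption, not the tutorial's square_add_prim).

```python
import jax
import jax.numpy as jnp

def square_add(a, b):
  # Plain jax.numpy version of a * a + b.
  return a * a + b

xs = jnp.array([2., 3.])
ys = jnp.array([10., 20.])

batched = jax.jit(jax.vmap(square_add, in_axes=0, out_axes=0))
print(batched(xs, ys))  # [14. 29.]

# The traced program operates on whole f32[2] vectors, not scalars.
print(jax.make_jaxpr(jax.vmap(square_add))(xs, ys))
```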