# Source code for jax.experimental.jet

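# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.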

r"""Jet is an experimental module for higher-order automatic differentiation
that does not rely on repeated first-order automatic differentiation.

How? Through the propagation of truncated Taylor polynomials.
Consider a function :math:`f = g \circ h`, some point :math:`x`
and some offset :math:`v`.
First-order automatic differentiation (such as :func:`jax.jvp`)
computes the pair :math:`(f(x), \partial f(x)[v])` from the pair
:math:`(h(x), \partial h(x)[v])`.

:func:`jet` implements the higher-order analogue:
Given the tuple

.. math::
  (h_0, ..., h_K) :=
  (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], ..., \partial^K h(x)[v,...,v]),

which represents a :math:`K`-th order Taylor approximation
of :math:`h` at :math:`x`, :func:`jet` returns a :math:`K`-th order
Taylor approximation of :math:`f` at :math:`x`,

.. math::
  (f_0, ..., f_K) :=
  (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], ..., \partial^K f(x)[v,...,v]).

More specifically, :func:`jet` computes

.. math::
  f_0, (f_1, ..., f_K) = \texttt{jet} (f, h_0, (h_1, ..., h_K))

and can thus be used for high-order
automatic differentiation of :math:`f`.
Details are explained in
`these notes <https://github.com/google/jax/files/6717197/jet.pdf>`__.

Note:
  Help improve :func:`jet` by contributing
  `outstanding primitive rules <https://github.com/google/jax/issues/2431>`__.
"""

from typing import Any, Callable

from functools import partial

import numpy as np

from jax import lax
import jax.numpy as jnp
from jax.experimental import pjit
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten,)

from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import sharding_impls
from jax._src.api_util import shaped_abstractify
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, weakref_lru_cache

def jet(fun, primals, series):
  r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Taylor approximation of ``fun``
      should be evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``. Each entry must be a single array or scalar (pytree inputs are
      not supported).
    series: Higher order Taylor-series-coefficients.
      Together, ``primals`` and ``series`` make up a truncated Taylor
      polynomial. Should be either a tuple or a list of tuples or lists, one
      per entry of ``primals``; the common length of the inner sequences
      dictates the degree of the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is
    ``fun(*primals)``, and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    Both values have the Python tree structure of the output of ``fun``; in
    ``series_out`` every leaf is replaced by the list of its Taylor
    coefficients.

  Raises:
    ValueError: If ``primals`` and ``series`` have different lengths, if the
      per-argument series do not all have the same length, or if a primal or
      a series term is not a pytree leaf.

  For example:

  >>> import jax
  >>> import jax.numpy as np
  >>> from jax.experimental.jet import jet

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
  >>> print(f0,  f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
  # zip() below would silently drop the unmatched tail, so reject a length
  # mismatch up front with a clear message.
  if len(primals) != len(series):
    raise ValueError(
        f"jet expected one series per primal, got {len(primals)} primals "
        f"and {len(series)} series")

  try:
    # Unpacking a one-element set asserts that every argument carries the
    # same number of terms; that number is the order of the jet.
    order, = set(map(len, series))
  except ValueError:
    msg = "jet terms have inconsistent lengths for different arguments"
    raise ValueError(msg) from None

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    treedef = tree_structure(x)
    if not treedef_is_leaf(treedef):
      raise ValueError(f"primal value at position {i} is not an array")
    for j, t in enumerate(terms):
      treedef = tree_structure(t)
      if not treedef_is_leaf(treedef):
        raise ValueError(f"term {j} for argument {i} is not an array")

  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    # Flatten the output of `fun`; the treedef is stored as auxiliary data so
    # that the results can be unflattened below.
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)

@lu.transformation
def jet_fun(order, primals, series):
  """Run the wrapped function under a fresh ``JetTrace`` main of given order.

  The wrapped function receives ``(main, primals, series)`` and must yield a
  pair ``(out_primals, out_terms)``. Outputs whose terms are the symbolic
  ``zero_series`` are replaced by ``order`` explicit zero arrays shaped like
  the corresponding primal output.
  """
  with core.new_main(JetTrace) as main:
    main.order = order
    primal_results, term_results = yield (main, primals, series), {}
    del main
  dense_terms = []
  for primal, terms in zip(primal_results, term_results):
    if terms is zero_series:
      terms = [jnp.zeros_like(primal)] * order
    dense_terms.append(terms)
  yield primal_results, dense_terms

@lu.transformation
def jet_subtrace(main, primals, series):
  """Wrap inputs in ``JetTracer``s, call the function, and unpack its outputs.

  Yields the pair ``(out_primals, out_terms)`` read off the output tracers
  after raising them to this trace.
  """
  trace = JetTrace(main, core.cur_sublevel())
  make_tracer = partial(JetTracer, trace)
  results = yield map(make_tracer, primals, series), {}
  raised = map(trace.full_raise, results)
  out_primals, out_terms = unzip2(
      (tracer.primal, tracer.terms) for tracer in raised)
  yield out_primals, out_terms

@lu.transformation_with_aux
def traceable(in_tree_def, *primals_and_series):
  """Adapt a ``(primals, series) -> (primals, series)`` function to flat args.

  The flat inputs are unflattened with ``in_tree_def`` before the call; the
  outputs are flattened again and their treedef is returned as auxiliary data.
  """
  unflat_primals, unflat_series = tree_unflatten(in_tree_def, primals_and_series)
  primals_out, series_out = yield (unflat_primals, unflat_series), {}
  flat_outputs, out_tree_def = tree_flatten((primals_out, series_out))
  yield flat_outputs, out_tree_def

class JetTracer(core.Tracer):
  """Tracer carrying a primal value together with its Taylor terms.

  ``terms`` is either the ``zero_series`` singleton or a list/tuple whose
  entries are coefficient values or ``zero_term`` placeholders.
  """
  __slots__ = ["primal", "terms"]

  def __init__(self, trace, primal, terms):
    assert type(terms) in (ZeroSeries, list, tuple)
    self._trace = trace
    self.primal = primal
    self.terms = terms

  @property
  def aval(self):
    """Abstract value of the primal."""
    return core.get_aval(self.primal)

  def full_lower(self):
    """Drop the tracer when every term is symbolically zero."""
    carries_terms = (self.terms is not zero_series
                     and any(t is not zero_term for t in self.terms))
    if carries_terms:
      return self
    return core.full_lower(self.primal)

class JetTrace(core.Trace):
  """Trace that pushes truncated Taylor series through primitives."""

  def pure(self, val):
    """Wrap a constant: its series is symbolically zero."""
    return JetTracer(self, val, zero_series)

  def lift(self, val):
    """Wrap a value from an outer trace: its series is symbolically zero."""
    return JetTracer(self, val, zero_series)

  def sublift(self, val):
    """Re-wrap a tracer's primal and terms in this trace."""
    return JetTracer(self, val.primal, val.terms)

  def process_primitive(self, primitive, tracers, params):
    """Apply the registered jet rule for ``primitive``.

    Symbolic zeros are materialised as explicit zero arrays before the rule
    is called. Raises ``KeyError`` if ``primitive`` has no entry in
    ``jet_rules``.
    """
    order = self.main.order  # pytype: disable=attribute-error
    primals_in, sparse_series = unzip2((t.primal, t.terms) for t in tracers)
    # TODO(mattjj): avoid always instantiating zeros
    series_in = []
    for primal, terms in zip(primals_in, sparse_series):
      if terms is zero_series:
        terms = [zero_term] * order
      series_in.append([
          jnp.zeros(np.shape(primal), dtype=jnp.result_type(primal))
          if term is zero_term else term
          for term in terms])
    rule = jet_rules[primitive]
    primal_out, terms_out = rule(primals_in, series_in, **params)
    if primitive.multiple_results:
      return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)]
    return JetTracer(self, primal_out, terms_out)

  def process_call(self, call_primitive, f, tracers, params):
    """Bind ``call_primitive`` on the jet-transformed version of ``f``."""
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    flat_args, in_tree_def = tree_flatten((primals_in, series_in))
    f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
    updater = call_param_updaters.get(call_primitive)
    if updater:
      new_params = updater(params, len(flat_args))
    else:
      new_params = params
    flat_out = call_primitive.bind(f_jet, *flat_args, **new_params)
    primals_out, series_out = tree_unflatten(out_tree_def(), flat_out)
    return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Flatten escaping tracers and return a callback that re-wraps them."""
    out_primals, out_series = unzip2((t.primal, t.terms) for t in out_tracers)
    flat_out, treedef = tree_flatten((out_primals, out_series))
    del out_primals, out_series
    main = self.main

    def todo(flat_vals):
      primals, series = tree_unflatten(treedef, flat_vals)
      trace = JetTrace(main, core.cur_sublevel())
      return map(partial(JetTracer, trace), primals, series)

    return flat_out, todo

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
                              symbolic_zeros):
    """Trace through ``fun`` directly; the custom JVP rule is ignored."""
    # TODO(mattjj): don't just ignore custom jvp rules?
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    """Trace through ``fun`` directly; the custom VJP rule is ignored."""
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)

class ZeroTerm:
  """Type of the ``zero_term`` placeholder for a single zero coefficient."""


zero_term = ZeroTerm()


def _flatten_zero_term(z):
  # No children and no auxiliary data.
  return (), None


def _unflatten_zero_term(_, xs):
  # Always hand back the module-level singleton.
  return zero_term


register_pytree_node(ZeroTerm, _flatten_zero_term, _unflatten_zero_term)

class ZeroSeries:
  """Type of the ``zero_series`` placeholder for an all-zero series."""


zero_series = ZeroSeries()


def _flatten_zero_series(z):
  # No children and no auxiliary data.
  return (), None


def _unflatten_zero_series(_, xs):
  # Always hand back the module-level singleton.
  return zero_series


register_pytree_node(ZeroSeries, _flatten_zero_series, _unflatten_zero_series)

# Optional per-call-primitive hooks. JetTrace.process_call looks the call
# primitive up here and, if an entry exists, calls it as
# ``updater(params, num_flat_inputs)`` to obtain the params used for binding.
call_param_updaters: dict[core.Primitive, Callable[..., Any]] = {}

### rule definitions

# Registry of jet rules, keyed by primitive. JetTrace.process_primitive calls
# ``jet_rules[primitive](primals_in, series_in, **params)``.
jet_rules = {}

def defzero(prim):
  """Register ``prim`` with a rule whose output series is symbolically zero."""
  rule = partial(zero_prop, prim)
  jet_rules[prim] = rule

def zero_prop(prim, primals_in, series_in, **params):
  """Jet rule that evaluates ``prim`` on the primals and returns ``zero_series``.

  The input series are ignored.
  """
  return prim.bind(*primals_in, **params), zero_series

# Primitives registered with the zero rule: comparisons, logical and bitwise
# operations, rounding-style operations, sign, is_finite, shifts and bitcast.
# Their output series is always the symbolic zero_series.
defzero(lax.le_p)
defzero(lax.lt_p)
defzero(lax.gt_p)
defzero(lax.ge_p)
defzero(lax.eq_p)
defzero(lax.ne_p)
defzero(lax.not_p)
defzero(lax.and_p)
defzero(lax.or_p)
defzero(lax.xor_p)
defzero(lax.floor_p)
defzero(lax.ceil_p)
defzero(lax.round_p)
defzero(lax.sign_p)
defzero(lax.is_finite_p)
defzero(lax.shift_left_p)
defzero(lax.shift_right_arithmetic_p)
defzero(lax.shift_right_logical_p)
defzero(lax.bitcast_convert_type_p)

def deflinear(prim):
  """Register ``prim`` with the linear rule (applied term by term)."""
  rule = partial(linear_prop, prim)
  jet_rules[prim] = rule

def linear_prop(prim, primals_in, series_in, **params):
  """Jet rule for a linear primitive.

  ``prim`` is bound once on the primals and once per order on the terms of
  that order taken from every input series.

  Returns:
    ``(primal_out, series_out)`` where ``series_out`` is a list with one
    entry per order.
  """
  apply = partial(prim.bind, **params)
  primal_out = apply(*primals_in)
  series_out = []
  for same_order_terms in zip(*series_in):
    series_out.append(apply(*same_order_terms))
  return primal_out, series_out

# Primitives registered with the linear rule: the primitive itself is applied
# to the primals and, order by order, to the series terms.
deflinear(lax.neg_p)
deflinear(lax.real_p)
deflinear(lax.complex_p)
deflinear(lax.conj_p)
deflinear(lax.imag_p)
deflinear(lax.sub_p)
deflinear(lax.convert_element_type_p)
deflinear(lax.concatenate_p)
deflinear(lax.reshape_p)
deflinear(lax.squeeze_p)
deflinear(lax.rev_p)
deflinear(lax.transpose_p)
deflinear(lax.slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax.fft_p)
deflinear(dispatch.device_put_p)

def _dynamic_slice_jet_rule(primals_in, series_in, **params):
operand, *start_indices = primals_in
primal_out = lax.dynamic_slice_p.bind(operand, *start_indices, **params)
series_out = [lax.dynamic_slice_p.bind(terms_in[0], *start_indices, **params)
for terms_in in zip(*series_in)]
return primal_out, series_out

jet_rules[lax.dynamic_slice_p] = _dynamic_slice_jet_rule

def _dynamic_update_slice_jet_rule(primals_in, series_in, **params):
operand, update, *start_indices = primals_in
primal_out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
series_out = [lax.dynamic_update_slice_p.bind(*terms_in[:2], *start_indices, **params)
for terms_in in zip(*series_in)]
return primal_out, series_out

jet_rules[lax.dynamic_update_slice_p] = _dynamic_update_slice_jet_rule

def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  """Jet rule for cumulative reductions, expressed through ``associative_scan``."""
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  scan = partial(lax.associative_scan, combine_fn, axis=axis, reverse=reverse)
  return jet(scan, primals_in, series_in)

# cumsum uses the linear rule; cumprod, cummax and cummin go through
# _cumulative_jet_rule with the matching binary combine function.
deflinear(lax.cumsum_p)
jet_rules[lax.cumprod_p] = partial(_cumulative_jet_rule,
combine_fn=lax.mul)
jet_rules[lax.cummax_p] = partial(_cumulative_jet_rule,
combine_fn=lax.max)
jet_rules[lax.cummin_p] = partial(_cumulative_jet_rule,
combine_fn=lax.min)

def def_deriv(prim, deriv):
  """Register a jet rule for ``prim`` built from its first derivative.

  Args:
    prim: The unary primitive to register.
    deriv: Callable computing the first derivative of ``prim``.
  """
  rule = partial(deriv_prop, prim, deriv)
  jet_rules[prim] = rule

def deriv_prop(prim, deriv, primals_in, series_in):
  """Jet rule for a unary primitive given its first derivative ``deriv``.

  The jet of ``deriv`` at the input gives coefficients ``c``; the output
  coefficients follow from
  ``v[k] = (k-1)! * sum_{j=1..k} c[k-j] * u[j] / ((k-j)! (j-1)!)``,
  with ``u`` the input coefficients.
  """
  (x,) = primals_in
  (series,) = series_in
  primal_out = prim.bind(x)
  deriv0, deriv_terms = jet(deriv, primals_in, series_in)
  deriv_coeffs = [deriv0] + deriv_terms
  in_coeffs = [x] + series
  out_coeffs = [primal_out]
  for k in range(1, len(series) + 1):
    acc = sum(_scale(k, j) * deriv_coeffs[k-j] * in_coeffs[j]
              for j in range(1, k + 1))
    out_coeffs.append(fact(k-1) * acc)
  return out_coeffs[0], out_coeffs[1:]

# erf is registered through its first derivative, written here as
# 2 / sqrt(pi) * exp(-x ** 2).
def_deriv(lax.erf_p,
lambda x: lax.mul(
lax_internal._const(x, 2. / np.sqrt(np.pi)),
lax.exp(lax.neg(lax.square(x)))))

def def_comp(prim, comp):
  """Register a jet rule for ``prim`` as the jet of an equivalent function.

  Args:
    prim: The primitive to register.
    comp: Callable computing the same result from simpler primitives.
  """
  rule = partial(jet, comp)
  jet_rules[prim] = rule

# Primitives whose jet rule is obtained by taking the jet of an equivalent
# expression written with other operations.
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))

def _erf_inv_rule(primals_in, series_in):
  """Jet rule for ``erf_inv``.

  The derivative is expressed on the co-domain as
  ``sqrt(pi) / 2 * exp(y ** 2)`` with ``y = erf_inv(x)``. Its coefficients are
  propagated by hand (square, then exp, then scaling), interleaved with the
  recurrence for the output coefficients.
  """
  (x,) = primals_in
  (series,) = series_in
  order = len(series)

  in_coeffs = [x] + series
  primal_out = lax.erf_inv(x)
  out_coeffs = [primal_out] + [None] * order

  # derivative on co-domain for caching purposes
  half_sqrt_pi = np.sqrt(np.pi) / 2.

  def deriv_y(y):
    return lax.mul(half_sqrt_pi, lax.exp(lax.square(y)))

  # manually propagate through deriv_y since we don't have lazy evaluation of
  # sensitivities
  deriv_coeffs = [deriv_y(primal_out)] + [None] * (order - 1)
  sq_coeffs = [lax.square(out_coeffs[0])] + [None] * (order - 1)
  exp_coeffs = [lax.exp(sq_coeffs[0])] + [None] * (order - 1)
  for k in range(1, order):
    # deriv_coeffs[:k] is known; compute out_coeffs[k], then deriv_coeffs[k].

    # propagate the derivative coefficients to get the output coefficient
    out_coeffs[k] = fact(k-1) * sum(
        _scale(k, j) * deriv_coeffs[k-j] * in_coeffs[j]
        for j in range(1, k + 1))

    # propagate the output coefficients to get the next derivative coefficient

    # square
    sq_coeffs[k] = fact(k) * sum(
        _scale2(k, j) * out_coeffs[k-j] * out_coeffs[j] for j in range(k + 1))

    # exp
    exp_coeffs[k] = fact(k-1) * sum(
        _scale(k, j) * exp_coeffs[k-j] * sq_coeffs[j] for j in range(1, k + 1))

    # const
    deriv_coeffs[k] = half_sqrt_pi * exp_coeffs[k]

  # The next derivative coefficient cannot be computed and is not needed;
  # only the last output coefficient remains.
  k = order
  out_coeffs[k] = fact(k-1) * sum(
      _scale(k, j) * deriv_coeffs[k-j] * in_coeffs[j] for j in range(1, k + 1))

  primal_out, *series_out = out_coeffs
  return primal_out, series_out
jet_rules[lax.erf_inv_p] = _erf_inv_rule

### More complicated rules

def fact(n):
  """Return ``n!`` computed as ``exp(lgamma(n + 1))``."""
  log_factorial = lax.lgamma(n + 1.)
  return lax.exp(log_factorial)

def _scale(k, j):
  """Return ``1 / ((k - j)! * (j - 1)!)``."""
  denominator = fact(k - j) * fact(j - 1)
  return 1. / denominator

def _scale2(k, j):
  """Return ``1 / ((k - j)! * j!)``."""
  denominator = fact(k - j) * fact(j)
  return 1. / denominator

def _exp_taylor(primals_in, series_in):
  """Jet rule for ``exp``.

  With input coefficients ``u`` and output coefficients ``v``:
  ``v[k] = (k-1)! * sum_{j=1..k} v[k-j] * u[j] / ((k-j)! (j-1)!)``.
  """
  (x,) = primals_in
  (series,) = series_in
  in_coeffs = [x] + series
  out_coeffs = [lax.exp(x)]
  for k in range(1, len(series) + 1):
    acc = sum(_scale(k, j) * out_coeffs[k-j] * in_coeffs[j]
              for j in range(1, k+1))
    out_coeffs.append(fact(k-1) * acc)
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.exp_p] = _exp_taylor

def _pow_taylor(primals_in, series_in):
  """Jet rule for ``pow``.

  The jet of ``r * log(u)`` is computed first; the exp recurrence is then
  applied to it, seeded with the primal ``u ** r``.
  """
  base, exponent = primals_in

  log_primal, log_series = jet(lambda x, y: lax.mul(y, lax.log(x)),
                               primals_in, series_in)

  in_coeffs = [log_primal] + log_series
  out_coeffs = [base ** exponent]
  for k in range(1, len(log_series) + 1):
    acc = sum(_scale(k, j) * out_coeffs[k-j] * in_coeffs[j]
              for j in range(1, k+1))
    out_coeffs.append(fact(k-1) * acc)

  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.pow_p] = _pow_taylor

def _integer_pow_taylor(primals_in, series_in, *, y):
  """Jet rule for ``integer_pow`` with static exponent ``y``.

  Exponents 0, 1 and 2 are delegated to the jet of an equivalent function.
  Other exponents use a recurrence that divides by ``x``; where ``x == 0``
  the resulting terms are set to 0.
  """
  if y == 0:
    small_case = jnp.ones_like
  elif y == 1:
    small_case = lambda x: x
  elif y == 2:
    small_case = lambda x: x * x
  else:
    small_case = None
  if small_case is not None:
    return jet(small_case, primals_in, series_in)

  (x,) = primals_in
  (series,) = series_in
  in_coeffs = [x] + series
  out_coeffs = [lax.integer_pow(x, y)]
  for k in range(1, len(series) + 1):
    vu = sum(_scale(k, j) * out_coeffs[k-j] * in_coeffs[j]
             for j in range(1, k + 1))
    uv = sum(_scale(k, j) * in_coeffs[k-j] * out_coeffs[j]
             for j in range(1, k))
    out_coeffs.append(jnp.where(x == 0, 0, fact(k-1) * (y * vu - uv) / x))

  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.integer_pow_p] = _integer_pow_taylor

def _logistic_taylor(primals_in, series_in):
  """Jet rule for ``logistic``.

  Two coupled recurrences are advanced together: one for the sigmoid
  coefficients and one for the coefficients of its derivative
  ``sigmoid * (1 - sigmoid)``.
  """
  (x,) = primals_in
  (series,) = series_in
  order = len(series)
  in_coeffs = [x] + series
  sig = [lax.logistic(x)] + [None] * order
  # terms for sigmoid' = sigmoid * (1 - sigmoid)
  dsig = [sig[0] * (1 - sig[0])] + [None] * order
  for k in range(1, order + 1):
    sig[k] = fact(k-1) * sum(
        _scale(k, j) * dsig[k-j] * in_coeffs[j] for j in range(1, k+1))
    dsig[k] = (1 - sig[0]) * sig[k] - fact(k) * sum(
        _scale2(k, j) * sig[j] * sig[k-j] for j in range(1, k+1))

  return sig[0], sig[1:]

jet_rules[lax.logistic_p] = _logistic_taylor

def _tanh_taylor(primals_in, series_in):
  """Jet rule for ``tanh`` via ``tanh(x) = 2 * logistic(2 * x) - 1``."""
  (x,) = primals_in
  (series,) = series_in
  doubled_terms = [2 * term for term in series]
  sig, sig_terms = _logistic_taylor((2*x,), (doubled_terms,))
  return 2 * sig - 1, [2 * term for term in sig_terms]
jet_rules[lax.tanh_p] = _tanh_taylor

def _log_taylor(primals_in, series_in):
  """Jet rule for ``log``.

  With input coefficients ``u`` and output coefficients ``v``:
  ``v[k] = (u[k] - (k-1)! * sum_{j=1..k-1} v[j] * u[k-j] / ((k-j)! (j-1)!)) / u[0]``.
  """
  (x,) = primals_in
  (series,) = series_in
  in_coeffs = [x] + series
  out_coeffs = [lax.log(x)]
  for k in range(1, len(series) + 1):
    conv = sum(_scale(k, j) * out_coeffs[j] * in_coeffs[k-j]
               for j in range(1, k))
    out_coeffs.append((in_coeffs[k] - fact(k - 1) * conv) / in_coeffs[0])
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.log_p] = _log_taylor

def _atan2_taylor(primals_in, series_in):
  """Jet rule for ``atan2``.

  The jet of the quotient of the two inputs is computed first, then the jet
  of ``1 / (1 + q ** 2)`` along it; the output coefficients follow from the
  same derivative-based recurrence used by ``deriv_prop``.
  """
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  quotient, quotient_series = jet(lax.div, primals_in, series_in)
  one = lax_internal._const(quotient, 1)
  deriv0, deriv_terms = jet(lambda x: lax.div(one, 1 + lax.square(x)),
                            (quotient, ), (quotient_series, ))
  deriv_coeffs = [deriv0] + deriv_terms
  in_coeffs = [quotient] + quotient_series
  out_coeffs = [primal_out]
  for k in range(1, len(quotient_series) + 1):
    acc = sum(_scale(k, j) * deriv_coeffs[k-j] * in_coeffs[j]
              for j in range(1, k + 1))
    out_coeffs.append(fact(k-1) * acc)
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.atan2_p] = _atan2_taylor

def _div_taylor_rule(primals_in, series_in):
  """Jet rule for ``div``.

  With numerator coefficients ``u``, denominator coefficients ``w`` and
  output coefficients ``v``:
  ``v[k] = (u[k] - k! * sum_{j=0..k-1} v[j] * w[k-j] / ((k-j)! j!)) / w[0]``.
  The primal output is the ``k = 0`` case of the same formula.
  """
  x, y = primals_in
  x_terms, y_terms = series_in
  numer = [x] + x_terms
  denom = [y] + y_terms

  def scale(k, j):
    return 1. / (fact(k - j) * fact(j))

  out_coeffs = []
  for k in range(len(numer)):
    conv = sum(scale(k, j) * out_coeffs[j] * denom[k-j] for j in range(0, k))
    out_coeffs.append((numer[k] - fact(k) * conv) / denom[0])
  primal_out, *series_out = out_coeffs
  return primal_out, series_out
jet_rules[lax.div_p] = _div_taylor_rule

def _sinusoidal_rule(sign, prims, primals_in, series_in):
  """Joint jet rule for a sine-like / cosine-like pair of functions.

  Args:
    sign: Factor applied to the cosine-like recurrence.
    prims: Pair ``(s, c)`` of callables evaluating the two functions.
    primals_in: One-element sequence holding the primal input.
    series_in: One-element sequence holding the input series.

  Returns:
    ``((s0, s_terms), (c0, c_terms))``: the jets of both functions.
  """
  (x,) = primals_in
  (series,) = series_in
  order = len(series)
  in_coeffs = [x] + series
  sin_fn, cos_fn = prims
  s = [sin_fn(x)] + [None] * order
  c = [cos_fn(x)] + [None] * order
  for k in range(1, order + 1):
    s[k] = fact(k-1) * sum(
        _scale(k, j) * in_coeffs[j] * c[k-j] for j in range(1, k + 1))
    c[k] = fact(k-1) * sum(
        _scale(k, j) * in_coeffs[j] * s[k-j] for j in range(1, k + 1)) * sign
  return (s[0], s[1:]), (c[0], c[1:])

def _get_ind(f, ind):
return lambda *args: f(*args)[ind]

# _sinusoidal_rule returns the jets of both functions of a pair; _get_ind
# selects index 0 for the sine-like rule and index 1 for the cosine-like rule.
# The trigonometric pair uses sign -1, the hyperbolic pair uses sign 1.
jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)
jet_rules[lax.sinh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 0)
jet_rules[lax.cosh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 1)

def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
  """Jet rule for a bilinear primitive.

  With input coefficients ``u`` and ``w`` and output coefficients ``v``:
  ``v[k] = k! * sum_{j=0..k} prim(u[j], w[k-j]) / ((k-j)! j!)``.
  The primal output is the ``k = 0`` case of the same formula.
  """
  x, y = primals_in
  x_terms, y_terms = series_in
  lhs = [x] + x_terms
  rhs = [y] + y_terms
  op = partial(prim.bind, **params)

  def scale(k, j):
    return 1. / (fact(k - j) * fact(j))

  out_coeffs = []
  for k in range(len(lhs)):
    acc = sum(scale(k, j) * op(lhs[j], rhs[k-j]) for j in range(0, k+1))
    out_coeffs.append(fact(k) * acc)
  primal_out, *series_out = out_coeffs
  return primal_out, series_out
jet_rules[lax.dot_general_p] = partial(_bilinear_taylor_rule, lax.dot_general_p)
jet_rules[lax.mul_p] = partial(_bilinear_taylor_rule, lax.mul_p)
jet_rules[lax.conv_general_dilated_p] = partial(_bilinear_taylor_rule, lax.conv_general_dilated_p)

def _gather_taylor_rule(primals_in, series_in, **params):
  """Jet rule for ``gather``.

  Every term of the operand's series is gathered with the primal indices;
  the series of the indices is ignored.
  """
  operand, start_indices = primals_in
  operand_terms, _ = series_in

  def gather_from(arr):
    return lax.gather_p.bind(arr, start_indices, **params)

  primal_out = gather_from(operand)
  series_out = [gather_from(term) for term in operand_terms]
  return primal_out, series_out
jet_rules[lax.gather_p] = _gather_taylor_rule

def _gen_reduce_choose_taylor_rule(chooser_fun):
  """Build a jet rule for a choosing reduction such as max or min.

  Each output term is the mean of the input term over the positions where
  the operand equals the reduced value.
  """
  def chooser_taylor_rule(primals_in, series_in, **params):
    (operand,) = primals_in
    (terms,) = series_in
    primal_out = chooser_fun(operand, **params)
    axes = params.pop("axes", None)
    term_dtype = terms[0].dtype
    # Shape of the result with the reduced axes kept as size 1.
    keepdims_shape = [1 if i in axes else d
                      for i, d in enumerate(operand.shape)]
    is_chosen = lax_internal._eq_meet(
        operand, lax.reshape(primal_out, keepdims_shape))
    location_indicators = lax.convert_element_type(is_chosen, term_dtype)
    counts = lax_internal._reduce_sum(location_indicators, axes)

    def mean_over_chosen(g):
      masked_sum = lax_internal._reduce_sum(
          lax.mul(g, location_indicators), axes)
      return lax.div(masked_sum, counts)

    series_out = [mean_over_chosen(g) for g in terms]
    return primal_out, series_out
  return chooser_taylor_rule
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(
    lax_internal._reduce_max)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(
    lax_internal._reduce_min)

def _abs_taylor_rule(x, series_in, **params):
x, = x
zero = lax.full_like(x, 0, shape=())
primal_out = lax.abs_p.bind(x, **params)
negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1.0))
fix_sign = lambda y: negs * y
series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
return primal_out, series_out
jet_rules[lax.abs_p] = _abs_taylor_rule

def _select_n_taylor_rule(primal_in, series_in, **params):
b, *cases = primal_in
primal_out = lax.select_n(b, *cases)
sel = lambda _, *xs: lax.select_n(b, *xs)
series_out = [sel(*terms_in) for terms_in in zip(*series_in)]
return primal_out, series_out
jet_rules[lax.select_n_p] = _select_n_taylor_rule

def _lax_max_taylor_rule(primal_in, series_in):

xgy = x > y   # greater than mask
xey = x == y  # equal to mask
primal_out = lax.select(xgy, x, y)

def select_max_and_avg_eq(x_i, y_i):
"""Select x where x>y or average when x==y"""
max_i = lax.select(xgy, x_i, y_i)
max_i = lax.select(xey, (x_i + y_i)/2, max_i)
return max_i

series_out = [select_max_and_avg_eq(*jnp.broadcast_arrays(*terms_in)) for terms_in in zip(*series_in)]
return primal_out, series_out
jet_rules[lax.max_p] = _lax_max_taylor_rule

def _lax_min_taylor_rule(primal_in, series_in):
x, y = primal_in
xgy = x < y   # less than mask
xey = x == y  # equal to mask
primal_out = lax.select(xgy, x, y)

def select_min_and_avg_eq(x_i, y_i):
"""Select x where x>y or average when x==y"""
min_i = lax.select(xgy, x_i, y_i)
min_i = lax.select(xey, (x_i + y_i)/2, min_i)
return min_i

series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
return primal_out, series_out
jet_rules[lax.min_p] = _lax_min_taylor_rule

def _scatter_add_rule(primals_in, series_in, *, update_jaxpr, update_consts,
                      dimension_numbers, indices_are_sorted, unique_indices,
                      mode):
  """Jet rule for ``scatter_add``.

  The scatter is applied with the primal indices to the primal
  operand/updates and, order by order, to the operand/updates series terms.
  The series of the indices is ignored.

  Args:
    primals_in: Triple ``(operand, scatter_indices, updates)``.
    series_in: Triple of series, one per primal.
    update_jaxpr, update_consts, dimension_numbers, indices_are_sorted,
      unique_indices, mode: Parameters forwarded unchanged to
      ``lax.scatter_add_p.bind``.

  Returns:
    ``(primal_out, series_out)`` where ``series_out`` is a list with one
    entry per order.
  """
  # One partially-applied bind keeps the primitive parameters identical for
  # the primal and for every series term.
  bind = partial(lax.scatter_add_p.bind, update_jaxpr=update_jaxpr,
                 update_consts=update_consts, dimension_numbers=dimension_numbers,
                 indices_are_sorted=indices_are_sorted,
                 unique_indices=unique_indices, mode=mode)
  operand, scatter_indices, updates = primals_in
  primal_out = bind(operand, scatter_indices, updates)
  series_out = [bind(d1, scatter_indices, d2) for d1, _, d2 in zip(*series_in)]
  return primal_out, series_out
jet_rules[lax.scatter_add_p] = _scatter_add_rule

@weakref_lru_cache
def _jet_jaxpr(
    jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def
) -> tuple[core.ClosedJaxpr, Any]:
  """Trace the jet of ``jaxpr`` into a new closed jaxpr.

  Returns:
    The closed jaxpr taking flat primals-and-series inputs, and the thunk
    producing the treedef of its ``(primals_out, series_out)`` outputs.
  """
  wrapped = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  jet_of_wrapped = jet_fun(jet_subtrace(wrapped), order)
  f_jet, out_tree_def = traceable(jet_of_wrapped, in_tree_def)
  traced_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      f_jet, primals_and_series_avals)
  return core.ClosedJaxpr(traced_jaxpr, consts), out_tree_def

def _pjit_jet_rule(primals_in, series_in, **params):
  """Jet rule for ``pjit``.

  The inner jaxpr is replaced by its jet, and the sharding and donation
  parameters are extended with unspecified / non-donated entries for the
  additional series inputs and outputs.
  """
  flat_args, in_tree_def = tree_flatten((primals_in, series_in))
  order = len(series_in[0])
  flat_avals = tuple(shaped_abstractify(arg) for arg in flat_args)
  jaxpr_jet, out_tree_def = _jet_jaxpr(params['jaxpr'], order,
                                       flat_avals, in_tree_def)
  num_series_in = len(primals_in) * order
  num_series_out = len(params['out_shardings']) * order

  unspecified = (sharding_impls.UNSPECIFIED,)
  new_params = dict(params)
  new_params['jaxpr'] = jaxpr_jet
  new_params['in_shardings'] = (
      params['in_shardings'] + unspecified * num_series_in)
  new_params['out_shardings'] = (
      params['out_shardings'] + unspecified * num_series_out)
  new_params['donated_invars'] = (
      params['donated_invars'] + (False,) * num_series_in)

  flat_out = pjit.pjit_p.bind(*flat_args, **new_params)
  return tree_unflatten(out_tree_def(), flat_out)

jet_rules[pjit.pjit_p] = _pjit_jet_rule