# jax.linearize
- `jax.linearize(fun: Callable, *primals, has_aux: Literal[False] = False) -> tuple[Any, Callable]`
- `jax.linearize(fun: Callable, *primals, has_aux: Literal[True]) -> tuple[Any, Callable, Any]`
Produces a linear approximation to `fun` using `jvp()` and partial evaluation.

- Parameters:
  - **fun** – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
  - **primals** – The primal values at which the Jacobian of `fun` should be evaluated. Should be a tuple of arrays, scalars, or standard Python containers thereof. The length of the tuple is equal to the number of positional parameters of `fun`.
  - **has_aux** – Optional, bool. Indicates whether `fun` returns a pair where the first element is considered the output of the mathematical function to be linearized and the second is auxiliary data. Default False.
- Returns: If `has_aux` is `False`, returns a pair where the first element is the value of `f(*primals)` and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of `fun` evaluated at `primals` without re-doing the linearization work. If `has_aux` is `True`, returns a `(primals_out, lin_fn, aux)` tuple where `aux` is the auxiliary data returned by `fun`.
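For example, a minimal sketch of the `has_aux=True` calling convention (the function `f_with_aux` and its auxiliary dictionary are made up for illustration):

```python
import jax
import jax.numpy as jnp

# Illustrative function returning (output, auxiliary data).
def f_with_aux(x):
    y = 3. * jnp.sin(x) + jnp.cos(x / 2.)
    return y, {"sin_x": jnp.sin(x)}

# With has_aux=True, linearize returns (primals_out, lin_fn, aux).
y, f_jvp, aux = jax.linearize(f_with_aux, 2.0, has_aux=True)
out_tangent = f_jvp(3.0)  # the linear map covers only the primary output
```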
In terms of values computed, `linearize()` behaves much like a curried `jvp()`, where these two code blocks compute the same values:

```python
y, out_tangent = jax.jvp(f, (x,), (in_tangent,))
```

```python
y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)
```
However, the difference is that `linearize()` uses partial evaluation so that the function `f` is not re-linearized on calls to `f_jvp`. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, `linearize()` has a similar signature to `vjp()`!)

This function is mainly useful if you want to apply `f_jvp` multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point; a sketch of that usage follows the next snippet. Moreover, if all the input tangent vectors are known at once, it can be more efficient to vectorize using `vmap()`, as in:

```python
from functools import partial
from jax import jvp, vmap

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
```
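Continuing with the same `f`, `x`, and `in_tangents` (the concrete values here are illustrative), the batched pushforward can also be computed by linearizing once and mapping the returned `f_jvp` over the tangents with `vmap()`; note that this variant still pays the stored-linearization memory cost discussed below:

```python
import jax
import jax.numpy as jnp

def f(x):
    return 3. * jnp.sin(x) + jnp.cos(x / 2.)

x = 2.0
in_tangents = jnp.array([1.0, 3.0, 4.0])  # illustrative batch of tangents

# Linearize once at x, then push each tangent through the stored linearization.
y, f_jvp = jax.linearize(f, x)
out_tangents = jax.vmap(f_jvp)(in_tangents)
```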
By using `vmap()` and `jvp()` together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both `linearize()` and `vjp()`.

Here's a more complete example of using `linearize()`:

```python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
```
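As a further sketch (the function `g` and its input are made up for illustration), the linear map returned by `linearize()` can be applied to standard basis vectors to build a Jacobian column by column at a fixed linearization point:

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sin(x) * x  # illustrative elementwise function

x = jnp.arange(1.0, 4.0)
y, g_jvp = jax.linearize(g, x)

# Pushing the j-th basis vector through g_jvp yields the j-th Jacobian column.
basis = jnp.eye(x.size)
jac = jnp.stack([g_jvp(e) for e in basis], axis=1)
```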