- jax.linearize(fun, *primals)
Produces a linear approximation to fun using jvp() and partial eval.
- Parameters
  - fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
  - primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalars, or standard Python containers thereof. The length of the tuple is equal to the number of positional parameters of fun; see the sketch after this list for the multi-argument case.
- Return type
  Tuple[Any, Callable]
- Returns
  A pair where the first element is the value of f(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without re-doing the linearization work.
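For instance, here is a minimal sketch of linearizing a function of two positional arguments (the function g and its inputs are illustrative, not from the docstring): one primal is passed per positional parameter, and the returned linear map takes one tangent per primal.

```python
import jax
import jax.numpy as jnp

def g(x, w):
    # A function of two positional arguments.
    return jnp.dot(w, x)

x = jnp.ones(3)
w = jnp.arange(3.0)

# One primal per positional parameter of g; the returned linear map
# takes one tangent per primal.
y, g_jvp = jax.linearize(g, x, w)
out_tangent = g_jvp(jnp.ones(3), jnp.zeros(3))
```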
In terms of values computed, linearize() behaves much like a curried jvp(), where these two code blocks compute the same values:

```python
y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)
```
However, the difference is that linearize() uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize() has a similar signature to vjp()!)
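As a sketch of that parallel (the function f and inputs here are made up for illustration), both transforms return a (value, linear map) pair, and the linearized map can be applied repeatedly without re-tracing f:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(3.0)

# Both return a (value, linear_map) pair; only the direction of the
# linear map differs.
y, f_jvp = jax.linearize(f, x)  # f_jvp pushes tangents forward
y2, f_vjp = jax.vjp(f, x)       # f_vjp pulls cotangents back

# Calls to f_jvp reuse the stored linearization; f is not re-traced.
t1 = f_jvp(jnp.ones_like(x))
t2 = f_jvp(2.0 * jnp.ones_like(x))
(ct,) = f_vjp(jnp.ones_like(y))
```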
This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover, if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap(), as in:
```python
pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
```
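A self-contained version of that pattern might look like this, with an assumed scalar function f and a made-up batch of tangents:

```python
from functools import partial

import jax.numpy as jnp
from jax import jvp, vmap

def f(x):
    return 3.0 * jnp.sin(x) + jnp.cos(x / 2.0)

x = jnp.array(2.0)
in_tangents = jnp.array([1.0, 2.0, 3.0])  # batch of input tangents

# pushfwd maps a tangent tuple to (f(x), tangent_out); vmap batches it
# over the leading axis of in_tangents. The primal output y does not
# depend on the tangent, hence out_axes=None for the first output.
pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
```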
Here’s a more complete example of using linearize():
```python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(DeviceArray(3.26819, dtype=float32, weak_type=True), DeviceArray(-5.00753, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
```
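As a sanity check (not part of the docstring): for a scalar x the linear map f_jvp(t) is just t times the analytic derivative f′(x) = 3·cos(x) − sin(x/2)/2, which reproduces the values above:

```python
import jax.numpy as jnp

# f_jvp is linear in its tangent argument; at x = 2. it multiplies the
# tangent by f'(2.) = 3*cos(2.) - sin(1.)/2.
dfdx = 3.0 * jnp.cos(2.0) - jnp.sin(1.0) / 2.0
print(dfdx * 3.0)  # ~ -5.007528, matches f_jvp(3.)
print(dfdx * 4.0)  # ~ -6.676704, matches f_jvp(4.)
```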