- jax.linear_transpose(fun, *primals, reduce_axes=())
Transpose a function that is promised to be linear.

For linear functions, this transformation is equivalent to `vjp`, but avoids the overhead of computing the forward pass.
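A minimal sketch of this equivalence, assuming a small hypothetical matrix `A`: the transpose of the linear map `x -> A @ x` should be `y -> A.T @ y`, and it should agree with the cotangent returned by `jax.vjp`. The difference is that `jax.vjp` evaluates `f` at a concrete point first, whereas `jax.linear_transpose` only needs the shape and dtype of its primal.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2x3 matrix; f(x) = A @ x is linear in x.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]], dtype=jnp.float32)

def f(x):
    return A @ x

x = jnp.zeros(3, dtype=jnp.float32)            # only its shape/dtype are used below
y = jnp.array([1.0, -1.0], dtype=jnp.float32)  # a cotangent for f's output

# Transpose without a forward pass; the result is a tuple, one entry per primal.
(xt,) = jax.linear_transpose(f, x)(y)

# vjp runs f forward at x, then pulls y back through it.
_, f_vjp = jax.vjp(f, x)
(xv,) = f_vjp(y)

print(xt)       # expected: [-3. -3. -3.]
print(A.T @ y)  # expected: [-3. -3. -3.]
```

For a genuinely linear `f`, the two results coincide; `linear_transpose` simply skips evaluating `f` on real data.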
The outputs of the transposed function will always have exactly the same dtypes as `primals`, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in `primals` that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.
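A sketch of this dtype rule, assuming a hypothetical linear map `f` from a real input to a complex output: with a `float32` primal the transposed function is forced to return `float32`, so the complex result is truncated, while a `complex64` primal leaves room for the full value. (JAX's transpose here is the plain linear transpose, with no complex conjugation.)

```python
import jax
import jax.numpy as jnp

def f(x):
    # Linear map whose output is complex even when x is real.
    return (1.0 + 1.0j) * x

ct = jnp.array(1.0 + 0.0j, dtype=jnp.complex64)  # cotangent for f's output

# Real primal: the transposed output keeps the primal's float32 dtype,
# so the complex result is truncated to fit.
(out_real,) = jax.linear_transpose(f, jnp.float32(0.0))(ct)

# Complex primal: the output dtype can represent the full complex result.
(out_cplx,) = jax.linear_transpose(f, jnp.complex64(0.0))(ct)

print(out_real.dtype, out_cplx.dtype)
print(out_cplx)
```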
- Parameters
fun (Callable) – the linear function to be transposed.
*primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types, used for evaluating the shape/dtype of `fun(*primals)`. These arguments may be real scalars/ndarrays, but that is not required: only the `shape` and `dtype` attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)
reduce_axes – Optional tuple of axis names. If an axis is listed here, and `fun` implicitly broadcasts a value over that axis, the backward pass will perform a `psum` of the corresponding cotangent. Otherwise, the transposed function will be per-example over named axes. For example, if `'batch'` is a named batch axis, `linear_transpose(f, *args, reduce_axes=('batch',))` will create a transpose function that sums over the batch, while `linear_transpose(f, *args)` will create a per-example transpose.
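The per-example versus summed distinction can be sketched with `jax.vmap` and a named axis. In this hypothetical setup the scalar weight `w` is unbatched while `x` is batched, so `w * x` implicitly broadcasts `w` over the `'batch'` axis. Without a reduction, each example gets its own cotangent for `w`. The sketch applies the `psum` by hand to show the summed result that the text attributes to `reduce_axes=('batch',)`, rather than passing that parameter, which may not exist in newer JAX releases.

```python
import jax
import jax.numpy as jnp

xs = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)  # batched data
w_struct = jax.ShapeDtypeStruct((), jnp.float32)      # unbatched scalar weight (shape/dtype only)

def per_example_ct(x):
    # w * x is linear in w; under vmap, w is broadcast over the 'batch' axis.
    f_t = jax.linear_transpose(lambda w: w * x, w_struct)
    (ct_w,) = f_t(jnp.float32(1.0))
    return ct_w

def summed_ct(x):
    # Manually sum the per-example cotangents across the named axis.
    return jax.lax.psum(per_example_ct(x), axis_name="batch")

per_example = jax.vmap(per_example_ct, axis_name="batch")(xs)
summed = jax.vmap(summed_ct, axis_name="batch")(xs)

print(per_example)  # one cotangent per example
print(summed)       # the batch total, repeated for each example
```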
- Returns
A callable that calculates the transpose of `fun`. Valid input into this function must have the same shape/dtypes/structure as the result of `fun(*primals)`. Output will be a tuple, with the same shape/dtypes/structure as `primals`.
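A sketch of the pytree and duck-typing behavior described for `*primals`, using `jax.ShapeDtypeStruct` as a data-free placeholder (an assumption chosen for illustration). The hypothetical `fun` takes a dict of two leaves; the transposed function accepts a cotangent shaped like `fun`'s output and returns a one-element tuple holding a dict shaped like the primal.

```python
import jax
import jax.numpy as jnp

# Placeholders carrying only shape and dtype, no real data.
primal = {
    "a": jax.ShapeDtypeStruct((2,), jnp.float32),
    "b": jax.ShapeDtypeStruct((), jnp.float32),
}

def fun(p):
    # Linear in both leaves; the scalar p["b"] is broadcast to shape (2,).
    return 2.0 * p["a"] + p["b"]

fun_t = jax.linear_transpose(fun, primal)

ct = jnp.array([1.0, 2.0], dtype=jnp.float32)  # same shape/dtype as fun's output
(grads,) = fun_t(ct)  # always a tuple, one entry per positional primal

print(grads["a"])  # 2 * ct
print(grads["b"])  # sum(ct), because p["b"] was broadcast in the forward map
```

Broadcasting in the forward map turns into summation in the transpose, which is why the scalar leaf receives the total of the cotangent.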
Example:

```python
>>> import jax
>>> import numpy as np
>>> import types
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
```