jax.linear_transpose
- jax.linear_transpose(fun, *primals, reduce_axes=())
Transpose a function that is promised to be linear.
For linear functions, this transformation is equivalent to vjp(), but avoids the overhead of computing the forward pass.
The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.
- Parameters
fun (Callable) – the linear function to be transposed.
*primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)
reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding cotangent. Otherwise, the transposed function will be per-example over named axes. For example, if 'batch' is a named batch axis, linear_transpose(f, *args, reduce_axes=('batch',)) will create a transpose function that sums over the batch, while linear_transpose(f, *args) will create a per-example transpose.
- Return type
- Returns
A callable that calculates the transpose of fun. Valid input into this function must have the same shape/dtypes/structure as the result of fun(*primals). Output will be a tuple, with the same shape/dtypes/structure as primals.
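The structure rule for the return value can be illustrated with a pytree primal. This is a minimal sketch, not taken from the JAX documentation: the function g, the dict keys "a" and "b", and the coefficients are hypothetical names chosen for illustration. Because g takes one positional argument (a dict), the transposed function should return a 1-tuple whose single element is a dict with the same keys.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear function of a dict-valued argument (illustration only).
def g(params):
    return 2.0 * params["a"] + 3.0 * params["b"]

# Only the shapes and dtypes of these values matter to linear_transpose.
primals = {"a": jnp.zeros(3), "b": jnp.zeros(3)}

g_transpose = jax.linear_transpose(g, primals)

# The cotangent must match the shape/dtype of g(primals): a float array of shape (3,).
(ct_params,) = g_transpose(jnp.ones(3))

# ct_params mirrors the structure of `primals`: a dict with keys "a" and "b".
print(ct_params)
```

The outer tuple has one entry per positional primal, which is why the result is unpacked with a trailing comma.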
>>> import jax
>>> import numpy as np
>>> import types
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
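The equivalence with vjp() stated in the description can be checked on a small matrix-vector product. The sketch below is an illustration under stated assumptions rather than part of the original page: the matrix A, the function f_mat, and the cotangent ct are hypothetical. For a linear map x ↦ A @ x, both approaches should apply the transpose of A to the cotangent.

```python
import jax
import jax.numpy as jnp

# Hypothetical 3x2 matrix defining a linear map from R^2 to R^3 (illustration only).
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

def f_mat(x):
    return A @ x

x = jnp.ones(2)                    # only shape/dtype are used by linear_transpose
ct = jnp.array([1.0, 0.0, -1.0])   # cotangent with the shape of f_mat(x)

# Transpose directly, without evaluating f_mat at x.
f_mat_transpose = jax.linear_transpose(f_mat, x)
(via_transpose,) = f_mat_transpose(ct)

# Same result through vjp, which also evaluates the forward pass.
_, f_mat_vjp = jax.vjp(f_mat, x)
(via_vjp,) = f_mat_vjp(ct)

# Both should equal A.T @ ct.
print(via_transpose, via_vjp)
```

Since f_mat is linear, the value of x passed to vjp() does not affect the cotangent it returns, which is why linear_transpose can skip the forward evaluation.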