jax.linear_transpose

jax.linear_transpose(fun, *primals, reduce_axes=())

Transpose a function that is promised to be linear.

For linear functions, this transformation is equivalent to vjp, but avoids the overhead of computing the forward pass.
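A minimal sketch of the equivalence with vjp, assuming a hypothetical linear function `matvec` that multiplies by a fixed matrix `A` (both names are introduced here for illustration only). For such a function, the transposed callable and the vjp callable should return the same cotangent, namely `A.T @ ct`:

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed matrix; matvec is linear in x.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

def matvec(x):
    return A @ x

x = jnp.zeros(2)                   # only shape and dtype matter for the transpose
ct = jnp.array([1.0, 0.0, -1.0])   # cotangent with the shape of matvec's output

# Transpose directly, without evaluating matvec at x.
(via_transpose,) = jax.linear_transpose(matvec, x)(ct)

# Same result through vjp, which also runs the forward pass.
_, matvec_vjp = jax.vjp(matvec, x)
(via_vjp,) = matvec_vjp(ct)

print(via_transpose, via_vjp)
```

Both callables return a one-element tuple here because `matvec` takes a single positional argument.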

The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.
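As a small illustration of the dtype rule above, the sketch below transposes the same function against two placeholder primals that differ only in dtype. The function `scale` and the use of `jax.ShapeDtypeStruct` as the placeholder are choices made for this example and are not taken from the text above:

```python
import jax
import jax.numpy as jnp

# Hypothetical linear function used only for illustration.
def scale(x):
    return 3.0 * x

# Placeholders carrying only shape and dtype.
spec32 = jax.ShapeDtypeStruct((2,), jnp.float32)
spec16 = jax.ShapeDtypeStruct((2,), jnp.float16)

(out32,) = jax.linear_transpose(scale, spec32)(jnp.ones(2, dtype=jnp.float32))
(out16,) = jax.linear_transpose(scale, spec16)(jnp.ones(2, dtype=jnp.float16))

# Each result carries the dtype of the primal it was transposed against.
print(out32.dtype, out16.dtype)
```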

Parameters
  • fun (Callable) – the linear function to be transposed.

  • *primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)

  • reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding cotangent. Otherwise, the transposed function will be per-example over named axes. For example, if 'batch' is a named batch axis, linear_transpose(f, *args, reduce_axes=('batch',)) will create a transpose function that sums over the batch, while linear_transpose(f, *args) will create a per-example transpose.
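The primals argument described above can be sketched with a pytree of placeholders. This example assumes a hypothetical function `combine` that is jointly linear in a dict and an array, and it uses `jax.ShapeDtypeStruct` as the shape/dtype carrier; it does not exercise reduce_axes:

```python
import jax
import jax.numpy as jnp

# Hypothetical function, jointly linear in (params, bias).
def combine(params, bias):
    return 2.0 * params["a"] + params["b"] - bias

# A placeholder leaf: no array data, only shape and dtype.
spec = jax.ShapeDtypeStruct((3,), jnp.float32)

combine_t = jax.linear_transpose(combine, {"a": spec, "b": spec}, spec)

ct = jnp.array([1.0, 2.0, 3.0])
params_ct, bias_ct = combine_t(ct)

# The result mirrors the structure of the primals: (dict, array).
print(params_ct["a"], params_ct["b"], bias_ct)
```

The dict is treated as a container, so its leaves are transposed individually, while each `ShapeDtypeStruct` is treated as a single array-like leaf.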

Return type

Callable

Returns

A callable that calculates the transpose of fun. Valid input to this callable must have the same shape/dtypes/structure as the result of fun(*primals). The output will be a tuple with the same shape/dtypes/structure as primals.

>>> import jax
>>> import numpy as np
>>> import types
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
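One way to sanity-check a transposed callable is the adjoint identity: for a linear f, the inner product of f(x) with y should equal the inner product of x with the transpose applied to y. The sketch below assumes a hypothetical finite-difference function `diff`, chosen only because its input and output shapes differ:

```python
import jax
import jax.numpy as jnp

# Hypothetical linear function: forward differences, maps shape (4,) to (3,).
def diff(x):
    return x[1:] - x[:-1]

x = jnp.array([1.0, 2.0, 4.0, 8.0])
y = jnp.array([3.0, -1.0, 2.0])      # same shape as diff(x)

(diff_t_y,) = jax.linear_transpose(diff, x)(y)   # same shape as x

lhs = jnp.vdot(diff(x), y)    # <f(x), y>
rhs = jnp.vdot(x, diff_t_y)   # <x, f^T(y)>
print(lhs, rhs)
```

If the two inner products disagree for random x and y, that usually indicates the function passed in was not actually linear.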