jax.extend.linear_util module

WrappedFun(f, transforms, stores, params, ...)

Represents a function f to which transforms are to be applied.

cache(call, *[, explain])

Memoization decorator for functions that take a WrappedFun as their first argument.

merge_linear_aux(aux1, aux2)

transformation(gen, fun, *gen_static_args)

Adds one more transformation to a WrappedFun.

transformation_with_aux(gen, fun, ...[, ...])

Adds one more transformation with auxiliary output to a WrappedFun.

wrap_init(f[, params])

Wraps the function f as a WrappedFun, suitable for transformation.
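The entries above describe a function bundled with a stack of pending transformations. The sketch below is a toy, pure-Python model of that idea. It does not import JAX, and every name in it (ToyWrappedFun, ToyStore, toy_wrap_init, toy_transformation, toy_transformation_with_aux) is hypothetical. It assumes a generator protocol in which a transformation yields the arguments for the next layer, receives that layer's output, and then yields its own output. The with_aux variant yields an (output, aux) pair instead. This protocol is inspired by the signatures listed above but is not guaranteed to match JAX's internals, which change between releases.

```python
from typing import Any, Callable, Optional, Tuple


class ToyStore:
    """Hypothetical one-shot slot for a transformation's auxiliary output."""

    def __init__(self) -> None:
        self._filled = False
        self._val: Any = None

    def store(self, val: Any) -> None:
        if self._filled:
            raise RuntimeError("store already filled")
        self._filled, self._val = True, val

    def __call__(self) -> Any:
        # Reading before the wrapped function has run is an error.
        if not self._filled:
            raise RuntimeError("store is empty; call the function first")
        return self._val


class ToyWrappedFun:
    """Hypothetical stand-in for WrappedFun: a base function, a stack of
    generator-based transformations, and keyword params for the base function."""

    def __init__(self, f: Callable, transforms: Tuple = (), params: Tuple = ()) -> None:
        self.f = f
        self.transforms = transforms  # (gen, gen_static_args, store), innermost first
        self.params = params          # sorted (name, value) pairs passed to f as kwargs

    def wrap(self, gen: Callable, gen_static_args: Tuple,
             store: Optional[ToyStore] = None) -> "ToyWrappedFun":
        # Returns a new object; `self` is left unchanged.
        return ToyWrappedFun(self.f, self.transforms + ((gen, gen_static_args, store),),
                             self.params)

    def call_wrapped(self, *args: Any) -> Any:
        stack = []
        # Outermost transformation runs first; each yields the args for the next layer.
        for gen, static_args, store in reversed(self.transforms):
            g = gen(*static_args, *args)
            args = next(g)
            stack.append((g, store))
        out = self.f(*args, **dict(self.params))
        # Outputs flow back out, innermost transformation first.
        for g, store in reversed(stack):
            result = g.send(out)
            if store is None:
                out = result
            else:
                out, aux = result  # with_aux transformations yield (output, aux)
                store.store(aux)
        return out


def toy_wrap_init(f: Callable, params: Optional[dict] = None) -> ToyWrappedFun:
    return ToyWrappedFun(f, (), tuple(sorted((params or {}).items())))


def toy_transformation(gen: Callable, fun: ToyWrappedFun, *gen_static_args: Any) -> ToyWrappedFun:
    return fun.wrap(gen, gen_static_args)


def toy_transformation_with_aux(gen: Callable, fun: ToyWrappedFun, *gen_static_args: Any):
    store = ToyStore()
    return fun.wrap(gen, gen_static_args, store), store


# Example transformations, written as generators.
def add_one_to_inputs(*args):
    out = yield tuple(a + 1 for a in args)
    yield out

def scale_output(factor, *args):
    out = yield args
    yield out * factor

def record_num_args(*args):
    out = yield args
    yield out, {"num_args": len(args)}


base = toy_wrap_init(lambda x, y: x * y)
f1 = toy_transformation(add_one_to_inputs, base)
f2 = toy_transformation(scale_output, f1, 10)
f3, aux = toy_transformation_with_aux(record_num_args, f2)

print(f3.call_wrapped(2, 3))    # (2+1)*(3+1) = 12, scaled by 10 -> 120
print(aux())                    # {'num_args': 2}
print(base.call_wrapped(2, 3))  # base is unchanged -> 6

scaled = toy_wrap_init(lambda x, *, scale: x * scale, {"scale": 5})
print(scaled.call_wrapped(3))   # params become keyword arguments -> 15
```

In this model, wrapping is lazy: each transformation call only records a layer and returns a new object, so the base function stays reusable. Nothing runs until call_wrapped. The auxiliary output is read through a one-shot store after the call, which is why the with_aux variant returns a pair rather than a single wrapped function.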