jax.extend.linear_util module
Represents a function f to which transforms are to be applied.

Memoization decorator for functions taking a WrappedFun as first argument.

Adds one more transformation to a WrappedFun.

Adds one more transformation with auxiliary output to a WrappedFun.

Wraps function f as a WrappedFun, suitable for transformation.
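The entries above describe four ideas: wrapping a plain function, stacking transformations on it, getting auxiliary output from a transformation, and memoizing on the wrapped function. The following is a minimal pure-Python sketch of those ideas under stated assumptions. It does not import or reproduce jax.extend.linear_util. Every name in it (WrappedFun, wrap_init, transformation, transformation_with_aux, cache, scale_output, count_args, evaluate) is a hypothetical stand-in, and the real JAX classes and signatures differ. Consult the JAX API reference before relying on any of them.

```python
import functools
from typing import Callable, Dict, Optional, Tuple


class WrappedFun:
    """Hypothetical stand-in: a function f plus a stack of transforms to apply."""

    def __init__(self, f: Callable, transforms: Tuple = (), params: Tuple = ()):
        self.f = f
        self.transforms = transforms  # tuple of (transform, extra_args) pairs
        self.params = params  # sorted (name, value) pairs passed to f as kwargs

    def wrap(self, transform: Callable, *extra) -> "WrappedFun":
        # Return a new WrappedFun with one more transform; self is not mutated.
        return WrappedFun(self.f, self.transforms + ((transform, extra),), self.params)

    def call_wrapped(self, *args):
        # f is innermost; each transform wraps every transform added before it.
        call = functools.partial(self.f, **dict(self.params))
        for transform, extra in self.transforms:
            call = functools.partial(transform, call, *extra)
        return call(*args)

    def key(self):
        # Hashable structural description, used as a memoization key.
        return (self.f, self.transforms, self.params)


def wrap_init(f: Callable, params: Optional[Dict] = None) -> WrappedFun:
    # Wrap a plain function so that transforms can be stacked on it.
    return WrappedFun(f, (), tuple(sorted((params or {}).items())))


def transformation(t: Callable) -> Callable:
    # Decorator: t(inner, *extra, *args) returns the transformed result.
    def apply(fun: WrappedFun, *extra) -> WrappedFun:
        return fun.wrap(t, *extra)
    return apply


def transformation_with_aux(t: Callable) -> Callable:
    # Decorator: t(inner, *extra, *args) returns (result, aux).
    def apply(fun: WrappedFun, *extra):
        box = []  # receives aux each time the wrapped function runs

        def run(inner, *args):
            ans, aux = t(inner, *args)
            box.append(aux)
            return ans

        # The thunk returns aux from the most recent call.
        return fun.wrap(run, *extra), lambda: box[-1]
    return apply


def cache(call: Callable) -> Callable:
    # Memoize call(fun, *args) on fun's structure plus the extra (hashable) args.
    memo = {}

    @functools.wraps(call)
    def memoized(fun: WrappedFun, *args):
        key = (fun.key(), args)
        if key not in memo:
            memo[key] = call(fun, *args)
        return memo[key]
    return memoized


# --- Example usage with hypothetical transforms ---
@transformation
def scale_output(inner, factor, *args):
    return factor * inner(*args)


@transformation_with_aux
def count_args(inner, *args):
    return inner(*args), len(args)


def add(x, y):
    return x + y


fun, num_args = count_args(scale_output(wrap_init(add), 10))
print(fun.call_wrapped(1, 2))  # 30
print(num_args())  # 2

evaluations = []


@cache
def evaluate(fun: WrappedFun, x, y):
    evaluations.append((x, y))  # runs only on a cache miss
    return fun.call_wrapped(x, y)


evaluate(scale_output(wrap_init(add), 10), 3, 4)
evaluate(scale_output(wrap_init(add), 10), 3, 4)  # structurally equal: cache hit
print(len(evaluations))  # 1
```

In this sketch, the cache key is built from the function, its transforms, and their static arguments. As a result, two separately constructed but identical stacks share one cache entry. A stack built with transformation_with_aux gets a fresh closure each time, so its key never matches another stack's.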