jax.extend.linear_util.transformation_with_aux
- jax.extend.linear_util.transformation_with_aux = functools.partial(<class 'functools.partial'>, <function transformation_with_aux>)
Adds one more transformation with auxiliary output to a WrappedFun.
- Parameters:
fun (WrappedFun)
- Return type:
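The `functools.partial(<class 'functools.partial'>, <function ...>)` value in the signature is what a currying helper of the form `partial(partial, f)` produces: it first takes the transformation, then the function to transform. To make the auxiliary-output idea concrete, here is a self-contained toy model in plain Python. It is not JAX's implementation. It assumes, based on our reading of JAX's generator-based transformations, that a transformation is a generator which first yields the (possibly modified) arguments, then receives the wrapped function's result and yields a pair `(result, aux)`. The aux value is stashed in a write-once store and read back through a thunk. `Store`, `ToyWrappedFun`, `_transformation_with_aux_impl`, `toy_transformation_with_aux` and `scale_and_report` are hypothetical names introduced only for this sketch.

```python
import functools
from typing import Any, Callable


class Store:
    """Write-once box for an auxiliary value (hypothetical toy helper)."""

    _EMPTY = object()

    def __init__(self) -> None:
        self._val: Any = Store._EMPTY

    def store(self, val: Any) -> None:
        if self._val is not Store._EMPTY:
            raise RuntimeError("store already holds a value")
        self._val = val

    @property
    def val(self) -> Any:
        if self._val is Store._EMPTY:
            raise RuntimeError("store is empty; call the wrapped function first")
        return self._val


class ToyWrappedFun:
    """Hypothetical stand-in for WrappedFun: a function plus a stack of transformations."""

    def __init__(self, f: Callable[..., Any], transforms: tuple = ()) -> None:
        self.f = f
        self.transforms = transforms  # innermost (first added) transformation first

    def wrap(self, gen, static_args: tuple, out_store) -> "ToyWrappedFun":
        return ToyWrappedFun(self.f, self.transforms + ((gen, static_args, out_store),))

    def call_wrapped(self, *args: Any) -> Any:
        stack = []
        # On the way in, the outermost (most recently added) transformation runs first.
        for gen, static_args, out_store in reversed(self.transforms):
            g = gen(*static_args, *args)
            args = next(g)  # the generator yields the arguments to pass inward
            stack.append((g, out_store))
        ans = self.f(*args)
        # On the way out, unwind in the opposite order.
        while stack:
            g, out_store = stack.pop()
            ans = g.send(ans)
            if out_store is not None:
                ans, aux = ans  # a with-aux transformation yields (result, aux)
                out_store.store(aux)
        return ans


def _transformation_with_aux_impl(gen, fun: ToyWrappedFun, *static_args: Any):
    out_store = Store()
    # Return the newly wrapped function plus a thunk that reads the aux value later.
    return fun.wrap(gen, static_args, out_store), lambda: out_store.val


# Currying via partial(partial, f); printing this object gives
# functools.partial(<class 'functools.partial'>, <function _transformation_with_aux_impl ...>).
toy_transformation_with_aux = functools.partial(functools.partial, _transformation_with_aux_impl)


@toy_transformation_with_aux
def scale_and_report(factor, *args):
    # Before the call: scale every positional argument.
    ans = yield tuple(factor * a for a in args)
    # After the call: pass the result through and report auxiliary data.
    yield ans, {"factor": factor, "n_args": len(args)}


wrapped, aux_thunk = scale_and_report(ToyWrappedFun(lambda x, y: x + y), 10)
result = wrapped.call_wrapped(1, 2)  # evaluates f(10, 20)
aux = aux_thunk()
print(result, aux)  # prints: 30 {'factor': 10, 'n_args': 2}
```

In this sketch, calling `aux_thunk()` before `call_wrapped` raises, because the auxiliary value only exists once the wrapped function has actually run. Returning a thunk instead of a value is what allows the transformation to be attached before any call happens.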