jax.extend.linear_util.transformation_with_aux


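jax.extend.linear_util.transformation_with_aux(gen, fun, *gen_static_args, use_eq_store=False)

The documented object is functools.partial(functools.partial, transformation_with_aux), so it is curried. transformation_with_aux(gen) returns a partial that is then called as (fun, *gen_static_args), which also lets it be applied to gen as a decorator. Passing gen and fun in a single call only builds another partial and does not apply the transformation.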
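The following minimal pure-Python sketch illustrates the currying implied by the functools.partial(functools.partial, ...) wrapping. It does not import JAX. _raw is a hypothetical stand-in that only echoes its arguments, so the call pattern is visible without any real transformation.

```python
from functools import partial


def _raw(gen, fun, *gen_static_args, use_eq_store=False):
    # Hypothetical stand-in for the wrapped function: just report what it received.
    return gen, fun, gen_static_args, use_eq_store


# Same wrapping as the documented object: partial(partial, f).
curried = partial(partial, _raw)

step = curried("gen")        # builds partial(_raw, "gen"); _raw is not called yet
print(step("fun", 1, 2))     # ('gen', 'fun', (1, 2), False)

# Supplying every argument in the first call only builds a partial.
premature = curried("gen", "fun", 1, 2)
print(isinstance(premature, partial))  # True

# The keyword can be supplied at either stage.
print(curried("gen", use_eq_store=True)("fun"))  # ('gen', 'fun', (), True)
```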

Adds one more transformation with auxiliary output to a WrappedFun.

Parameters:

fun (WrappedFun): the wrapped function to which gen is added as one more transformation.

Return type:

tuple[WrappedFun, Any]
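The return pair can be read as (transformed function, handle for the auxiliary output). The sketch below is a simplified, hypothetical model of that idea in plain Python. It assumes, in loose imitation of JAX's internal store, that the auxiliary value is written to a write-once store when the transformed function runs and is read back through a zero-argument thunk. Store, _with_aux, transformation_with_aux_sketch and scale_and_count are illustrative names only. The sketch omits WrappedFun, use_eq_store, and the real gen protocol, which has varied between JAX releases.

```python
from functools import partial
from typing import Any, Callable, Tuple


class Store:
    """Write-once box for an auxiliary value (hypothetical model)."""

    _EMPTY = object()

    def __init__(self) -> None:
        self._val: Any = Store._EMPTY

    def store(self, val: Any) -> None:
        if self._val is not Store._EMPTY:
            raise RuntimeError("store already filled")
        self._val = val

    @property
    def val(self) -> Any:
        if self._val is Store._EMPTY:
            raise RuntimeError("store is empty: call the transformed function first")
        return self._val


def _with_aux(gen: Callable, fun: Callable,
              *gen_static_args: Any) -> Tuple[Callable, Callable[[], Any]]:
    """Wrap fun with gen; return (transformed_fun, aux_thunk)."""
    out_store = Store()

    def transformed(*args: Any) -> Any:
        # In this model, gen returns (primary_output, auxiliary_output).
        out, aux = gen(fun, *gen_static_args, *args)
        out_store.store(aux)  # stash aux; the caller reads it later via the thunk
        return out

    return transformed, (lambda: out_store.val)


# Curried the same way, so decorating gen yields partial(_with_aux, gen).
transformation_with_aux_sketch = partial(partial, _with_aux)


@transformation_with_aux_sketch
def scale_and_count(fun, scale, *args):
    # Primary output: scaled result. Auxiliary output: number of arguments.
    return fun(*args) * scale, len(args)


f, aux = scale_and_count(lambda x, y: x + y, 10)  # 10 is a static arg
print(f(1, 2))  # 30
print(aux())    # 2
```

The thunk is only meaningful after the transformed function has been called. Reading it earlier, or calling the transformed function a second time, raises in this model.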