jax.extend.linear_util.transformation
- jax.extend.linear_util.transformation = functools.partial(<class 'functools.partial'>, <function transformation>)
Adds one more transformation to a WrappedFun.
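To illustrate how a generator-based transformation composes with a wrapped function, here is a minimal, self-contained toy model. It does not use JAX's implementation. `ToyWrappedFun`, `_transformation`, `scale_output`, and `add_to_inputs` are hypothetical names. It assumes the generator protocol implied by the parameter descriptions: the generator `gen` is called with the static args followed by the call args, yields `(args, kwargs)` for the inner function, receives the inner result, and then yields the final output. The `partial(partial, ...)` currying mirrors the `functools.partial(<class 'functools.partial'>, ...)` value shown in the signature. Applied as a decorator, it turns a generator into a function that is later called as `(fun, *gen_static_args)`. The real `WrappedFun` tracks additional state that this toy omits, such as auxiliary output stores and parameters.

```python
from functools import partial


class ToyWrappedFun:
    """Hypothetical stand-in for WrappedFun: a callable plus a stack of
    generator-based transformations (outermost first)."""

    def __init__(self, f, transforms=()):
        self.f = f
        self.transforms = transforms

    def wrap(self, gen, gen_static_args):
        # Return a new object; the newest transformation becomes the outermost one.
        return ToyWrappedFun(self.f, ((gen, gen_static_args),) + self.transforms)

    def call_wrapped(self, *args, **kwargs):
        stack = []
        # Walk inward: each generator gets its static args followed by the
        # current call args and yields the (args, kwargs) for the next layer.
        for gen_fn, static_args in self.transforms:
            gen = gen_fn(*static_args, *args, **kwargs)
            args, kwargs = next(gen)
            stack.append(gen)
        ans = self.f(*args, **kwargs)
        # Walk outward: send the result back into each generator (innermost
        # first); its second yield becomes the new result.
        while stack:
            ans = stack.pop().send(ans)
        return ans


def _transformation(gen, fun, *gen_static_args):
    """Adds one more transformation to a ToyWrappedFun (hypothetical helper)."""
    return fun.wrap(gen, gen_static_args)


# Curried like the documented object: transformation(gen) returns
# partial(_transformation, gen), which is later called as (fun, *static_args).
transformation = partial(partial, _transformation)


@transformation
def scale_output(factor, *args):
    ans = yield args, {}  # pass the inputs through unchanged
    yield ans * factor    # post-process the output


@transformation
def add_to_inputs(offset, *args):
    ans = yield tuple(a + offset for a in args), {}  # pre-process the inputs
    yield ans                                        # pass the output through


f = ToyWrappedFun(lambda x, y: x * y)
h = add_to_inputs(scale_output(f, 10), 1)  # add_to_inputs is now outermost
result = h.call_wrapped(2, 3)  # (2 + 1) * (3 + 1) * 10 == 120
```

Each call such as `scale_output(f, 10)` returns a new wrapped object rather than mutating `f`. As a result, the same base function can carry different transformation stacks.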
- Parameters:
gen – the transformation generator function
fun (WrappedFun) – a WrappedFun on which to apply the transformation
gen_static_args – static arguments for the generator function, passed to gen ahead of the wrapped function's call arguments
- Return type: