jax.extend.linear_util.transformation

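jax.extend.linear_util.transformation(gen, fun, *gen_static_args)

Adds one more transformation to a WrappedFun.

The function is curried: its value is functools.partial(functools.partial, transformation). Calling it with gen alone returns a callable that later accepts fun and gen_static_args and produces the WrappedFun. This is what lets it be applied as a decorator to a transformation generator function.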

Parameters:
  • gen – the transformation generator function

  • fun (WrappedFun) – the WrappedFun to which the transformation is applied

  • gen_static_args – static arguments passed to gen, ahead of the call arguments, when the transformed WrappedFun is called

Return type:

WrappedFun
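The layering this function performs can be illustrated with a small self-contained model. The sketch below does not import JAX: MiniWrappedFun, _mini_transformation, mini_transformation, add_to_inputs, and scale_output are hypothetical names invented for illustration. It assumes a simplified generator protocol with positional arguments only. In this protocol the generator receives its static arguments followed by the call arguments and yields the (possibly modified) arguments. It is then sent the wrapped function's result and yields the final output. The ordering used here, where the most recently added transformation runs outermost, is likewise an assumption of the sketch rather than a statement about JAX internals.

```python
import functools
from typing import Callable, Tuple


class MiniWrappedFun:
    """Hypothetical, simplified stand-in for WrappedFun (not the JAX class)."""

    def __init__(self, f: Callable, transforms: Tuple = ()):
        self.f = f
        # Stack of (gen, gen_static_args) pairs; the most recently added is first.
        self.transforms = transforms

    def call_wrapped(self, *args):
        pending = []
        # Outermost layer first: each generator gets its static args followed
        # by the current call args and yields the args for the next layer.
        for gen, static_args in self.transforms:
            g = gen(*static_args, *args)
            args = next(g)
            pending.append(g)
        ans = self.f(*args)
        # Innermost layer first: send the result back out through each generator.
        while pending:
            ans = pending.pop().send(ans)
        return ans


def _mini_transformation(gen, fun, *gen_static_args):
    # Hypothetical helper: build a new wrapper with one more layer;
    # `fun` itself is left unchanged.
    return MiniWrappedFun(fun.f, ((gen, gen_static_args),) + fun.transforms)


# Curried in the same shape as the rendered signature, partial(partial, f),
# so it can decorate a generator function.
mini_transformation = functools.partial(functools.partial, _mini_transformation)


@mini_transformation
def add_to_inputs(delta, *args):
    out = yield tuple(a + delta for a in args)  # adjust the inputs
    yield out                                   # pass the output through


@mini_transformation
def scale_output(factor, *args):
    out = yield args       # pass the inputs through
    yield out * factor     # adjust the output


base = MiniWrappedFun(lambda x: x * 10)
inner = add_to_inputs(base, 1)   # added first, so it runs innermost
outer = scale_output(inner, 2)   # added last, so it runs outermost
print(outer.call_wrapped(3))     # 80, i.e. ((3 + 1) * 10) * 2
print(base.call_wrapped(3))      # 30: the original wrapper is unchanged
```

Because _mini_transformation builds a new wrapper instead of mutating fun, each call adds exactly one layer while earlier wrappers keep working as before. A real WrappedFun carries additional state that this sketch omits.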