jax.extend.linear_util.cache

jax.extend.linear_util.cache(call, *, explain=None)

Memoization decorator for functions taking a WrappedFun as first argument.

Parameters:
  • call (Callable) – a Python callable that takes a WrappedFun as its first argument. The underlying transforms and params on the WrappedFun are used as part of the memoization cache key.

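  • explain (Callable | None) – an optional callback invoked on cache misses to explain why the miss occurred. Defaults to None, in which case no explanation is produced.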

Returns:

A memoized version of call.
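To make the cache-key behaviour concrete, here is a minimal pure-Python sketch of the same memoization pattern. It does not use JAX: ToyWrappedFun, toy_cache, and the explain(fun, key) hook signature are hypothetical stand-ins, the sketch assumes transforms, params, and positional arguments are all hashable, and the real WrappedFun and cache key may carry more state than shown here.

```python
import weakref
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class ToyWrappedFun:
    """Hypothetical stand-in for JAX's WrappedFun (not the real class)."""
    f: Callable[..., Any]
    transforms: tuple = ()  # stack of transformations applied to f
    params: tuple = ()      # hashable params, e.g. (("axis", 0),)

    def call_wrapped(self, *args):
        # The real WrappedFun would run its transforms; the toy just calls f.
        return self.f(*args)


def toy_cache(call: Callable, *, explain: Optional[Callable] = None):
    """Memoize call(fun, *args) on (fun.transforms, fun.params, args)."""
    # One table per underlying Python function. Weak keys let a table be
    # dropped once that function is garbage-collected.
    fun_caches = weakref.WeakKeyDictionary()

    def memoized(fun: ToyWrappedFun, *args):
        table = fun_caches.setdefault(fun.f, {})
        key = (fun.transforms, fun.params, args)
        if key in table:
            return table[key]           # cache hit: call is not invoked
        if explain is not None:
            explain(fun, key)           # hypothetical hook: report the miss
        result = call(fun, *args)
        table[key] = result
        return result

    return memoized


# Usage: count how often the expensive `call` actually runs.
calls = []
misses = []

def trace(fun, x):
    calls.append(x)
    return fun.call_wrapped(x)

traced = toy_cache(trace, explain=lambda fun, key: misses.append(key))

def f(x):
    return 2 * x

plain = ToyWrappedFun(f)
batched = ToyWrappedFun(f, transforms=("batch",))

print(traced(plain, 3))    # 6 (miss: trace runs)
print(traced(plain, 3))    # 6 (hit: trace is skipped)
print(traced(batched, 3))  # 6 (miss: same f, different transforms)
```

Because the transforms are part of the key, plain and batched share the underlying f but occupy separate cache entries, which mirrors the description of the call parameter above.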