jax.extend.linear_util.cache#

jax.extend.linear_util.cache(call, *, explain=None)[source]#

用于记忆化将 WrappedFun 作为第一个参数的函数的装饰器。

参数:
  • call (Callable) – 一个 Python 可调用对象,它将 WrappedFun 作为其第一个参数。WrappedFun 上的底层转换和参数用作记忆化缓存键的一部分。

  • explain (Callable[[WrappedFun, bool, dict, tuple, float], None] | None | None) – 一个函数,当缓存未命中时调用该函数以记录未命中的解释。使用 (fun, is_cache_first_use, cache, key, elapsed_sec) 调用。

返回:

call 的记忆化版本。