jax.extend.linear_util.cache#

jax.extend.linear_util.cache(call, *, explain=None)[source]#

用于接受 WrappedFun 作为第一个参数的函数的记忆化装饰器。

参数:
  • call (可调用对象) – 一个 Python 可调用对象,它以 WrappedFun 作为其第一个参数。WrappedFun 上的底层变换和参数被用作记忆化缓存键的一部分。

  • explain (可调用对象[[WrappedFun, 布尔值, 字典, 元组, 浮点数], None] | None) – 一个在缓存未命中时调用的函数,用于记录未命中的解释。调用时传入 (fun, is_cache_first_use, cache, key, elapsed_sec)

返回:

call 的记忆化版本。