jax.extend.linear_util.cache#
- jax.extend.linear_util.cache(call, *, explain=None)[源代码]#
用于记忆第一个参数为 WrappedFun 的函数的记忆化装饰器。
- 参数:
call (Callable) – 一个 Python 可调用对象,其第一个参数是 WrappedFun。WrappedFun 上的底层转换和参数将用作记忆化缓存键的一部分。
explain (Callable[[WrappedFun, bool, dict, tuple, float], None] | None) – 一个函数,在缓存未命中时被调用,以记录未命中的解释。调用时传入 (fun, is_cache_first_use, cache, key, elapsed_sec)。
- 返回:
call的记忆化版本。