jax.extend.linear_util.WrappedFun#
- class jax.extend.linear_util.WrappedFun(f, f_transformed, transforms, stores, params, in_type, debug_info)[源代码]#
表示函数 f,将要对其应用 transforms。
- 参数:
f (Callable) – 要转换的函数。
f_transformed (Callable) – 转换后的函数。
transforms (tuple[tuple[Callable, tuple[Hashable, ...]], ...]) – 由 (gen, gen_static_args) 元组组成的元组,表示要应用于 f 的转换。其中 gen 是一个生成器函数,gen_static_args 是生成器的静态参数元组。有关生成器的预期行为,请参阅本模块开头的描述。
stores (tuple[Store | EqualStore | None, ...]) – 用于 transforms 辅助输出的 out_store 列表。
params (tuple[tuple[str, Any], ...]) – 由 (name, param) 元组组成的元组,表示要作为关键字参数传递给 f 的额外参数,以及转换后的关键字参数。
in_type (core.InputType | None) – 可选的输入类型
debug_info (DebugInfo) – 关于被包装函数的调试信息。
方法
__init__(f, f_transformed, transforms, ...)call_wrapped(*args, **kwargs)调用转换后的函数
populate_stores(stores)将 stores 中的值复制到 self.stores。
replace_debug_info(dbg)with_unknown_names()wrap(gen, gen_static_args, out_store)添加另一个转换及其存储。
属性
ff_transformedtransformsstoresparamsin_typedebug_info