jax.extend.linear_util.WrappedFun#

class jax.extend.linear_util.WrappedFun(f, f_transformed, transforms, stores, params, in_type, debug_info)[source]#

表示要应用 transforms 的函数 f

参数:
  • f (Callable) – 要转换的函数。

  • f_transformed (Callable) – 转换后的函数。

  • transforms (tuple[tuple[Callable, tuple[Hashable, ...]], ...]) – 代表要应用于 f 的转换的 (gen, gen_static_args) 元组。此处 gen 是一个生成器函数,gen_static_args 是生成器的静态参数元组。有关生成器预期行为的描述,请参见本模块的开头。

  • stores (tuple[Store | EqualStore | None, ...]) – transforms 的辅助输出的 out_store 列表。

  • params (tuple[tuple[str, Any], ...]) – (name, param) 元组,表示要作为关键字参数传递给 f 以及转换后的关键字参数的额外参数。

  • in_type (core.InputType | None) – 可选输入类型

  • debug_info (DebugInfo) – 关于被包装函数的调试信息。

__init__(f, f_transformed, transforms, stores, params, in_type, debug_info)[source]#
参数:
  • f (Callable)

  • f_transformed (Callable)

  • transforms (tuple[tuple[Callable, tuple[Hashable, ...]], ...])

  • stores (tuple[Store | EqualStore | None, ...])

  • params (tuple[tuple[str, Hashable], ...])

  • in_type (core.InputType | None)

  • debug_info (DebugInfo)

方法

__init__(f, f_transformed, transforms, ...)

call_wrapped(*args, **kwargs)

调用转换后的函数

populate_stores(stores)

将值从 stores 复制到 self.stores 中。

wrap(gen, gen_static_args, out_store)

添加另一个转换及其存储。

属性

f

f_transformed

transforms

stores

params

in_type

debug_info