jax.stages 模块#

编译执行过程的阶段接口。

JAX 的编译时(just-in-time)转换,如 jax.jitjax.pmap,也支持一种通用的提前(ahead of time)显式降低(lowering)和编译的方式。本模块定义了表示此过程阶段的类型。

更多信息,请参阅 AOT 教程

#

class jax.stages.Wrapped(*args, **kwargs)[source]#

一个准备好进行追踪(tracing)、降低(lowering)和编译(compilation)的函数。

此协议反映了 jax.jit 等函数的输出。调用它将执行 JIT(即时)降低、编译和执行。它也可以在编译前显式降低,并将结果在执行前编译。

__call__(*args, **kwargs)[source]#

执行被包装的函数,根据需要进行降低和编译。

lower(*args, **kwargs)[source]#

显式地为给定参数降低此函数。

这是 self.trace(*args, **kwargs).lower() 的快捷方式。

一个已降低的函数已被迁出 Python,并翻译成编译器的输入语言,可能以一种依赖于后端的方式。它已准备好进行编译,但尚未编译。

返回:

一个表示降低的 Lowered 实例。

返回类型:

Lowered

trace(*args, **kwargs)[source]#

显式地为给定参数追踪此函数。

一个已追踪的函数已被迁出 Python,并翻译成 jaxpr。它已准备好进行降低,但尚未降低。

返回:

一个表示追踪的 Traced 实例。

返回类型:

Traced

class jax.stages.Traced(lfg, params, in_tree, out_tree, num_consts)[source]#

一个针对参数类型和值特化的函数追踪形式。

已追踪的计算已准备好进行降低。此类携带了追踪表示以及稍后降低、编译和执行它所需的剩余信息。

lower(*, lowering_platforms=None, _private_parameters=None)[source]#

降低到编译器输入,返回一个 Lowered 实例。

参数:
  • lowering_platforms (tuple[str, ...] | None)

  • _private_parameters (mlir.LoweringParameters | None)

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False, in_types=None, out_types=None)[source]#

一个针对参数类型和值特化的函数降低表示。

降低(lowering)是一个准备好进行编译的计算。此类携带了降低表示以及稍后编译和执行它所需的剩余信息。它还为查询 JAX 各个降低路径(jit(), pmap() 等)的已降低计算的属性提供了通用 API。

参数:
  • lowering (Lowering)

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

  • no_kwargs (bool)

as_text(dialect=None, *, debug_info=False)[source]#

此降低的可读文本表示。

用于可视化和调试目的。这不一定是一个有效的或可靠的序列化。如果您想要可靠且可移植的序列化,请使用 jax.export

参数:
  • dialect (str | None) – 可选字符串,指定降低方言(例如,“stablehlo”或“hlo”)。

  • debug_info (bool) – 是否包含调试信息,例如源位置。

返回类型:

str

compile(compiler_options=None, *, device_assignment=None)[source]#

编译,返回相应的 Compiled 实例。

参数:
  • compiler_options (CompilerOptions | None)

  • device_assignment (tuple[xc.Device, ...] | None)

返回类型:

Compiled

compiler_ir(dialect=None)[source]#

此降低的任意对象表示。

用于调试目的。这不是一个有效或可靠的序列化。输出在不同调用之间没有一致性保证。如果您想要可靠且可移植的序列化,请使用 jax.export

如果不可用(例如,取决于后端、编译器或运行时),则返回 None

参数:

dialect (str | None) – 可选字符串,指定降低方言(例如,“stablehlo”或“hlo”)。

返回类型:

Any | None

cost_analysis()[source]#

执行成本估算的摘要。

用于可视化和调试目的。此方法输出的对象是一个简单的、易于打印或序列化的数据结构(例如,嵌套的字典、列表和具有数值叶子的元组)。然而,其结构可能是任意的:它在 JAX 和 jaxlib 的版本之间,甚至在调用之间都可能不一致。

如果不可用(例如,取决于后端、编译器或运行时),则返回 None

返回类型:

Any | None

property in_tree: tree_util.PyTreeDef[source]#

位置参数和关键字参数对的树结构。

class jax.stages.Compiled(executable, const_args, args_info, out_tree, no_kwargs=False, in_types=None, out_types=None)[source]#

一个针对类型/值特化的函数编译表示。

已编译的计算与一个可执行文件以及执行它所需的剩余信息相关联。它还为查询 JAX 各种编译路径和后端已编译计算的属性提供了一个通用 API。

参数:
  • const_args (list[ArrayLike])

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

__call__(*args, **kwargs)[source]#

将 self 调用为一个函数。

as_text()[source]#

此可执行文件的可读文本表示。

用于可视化和调试目的。这不是一个有效或可靠的序列化。

如果不可用(例如,取决于后端、编译器或运行时),则返回 None

返回类型:

str | None

cost_analysis()[source]#

执行成本估算的摘要。

用于可视化和调试目的。此方法输出的对象是一个简单的、易于打印或序列化的数据结构(例如,嵌套的字典、列表和具有数值叶子的元组)。然而,其结构可能是任意的:它在 JAX 和 jaxlib 的版本之间,甚至在调用之间都可能不一致。

如果不可用(例如,取决于后端、编译器或运行时),则返回 None

返回类型:

Any | None

property in_tree: tree_util.PyTreeDef[source]#

位置参数和关键字参数对的树结构。

memory_analysis()[source]#

估计内存需求的摘要。

用于可视化和调试目的。此方法输出的对象是一个简单的、易于打印或序列化的数据结构(例如,嵌套的字典、列表和具有数值叶子的元组)。然而,其结构可能是任意的:它在 JAX 和 jaxlib 的版本之间,甚至在调用之间都可能不一致。

如果不可用(例如,取决于后端、编译器或运行时),则返回 None

返回类型:

Any | None

runtime_executable()[source]#

此可执行文件的任意对象表示。

用于调试目的。这不是一个有效或可靠的序列化。输出在不同调用之间没有一致性保证。

如果不可用(例如,取决于后端、编译器或运行时),则返回 None

返回类型:

Any | None