jax.stages 模块#
编译执行过程的阶段接口。
JAX 的编译时(just-in-time)转换,如 jax.jit 和 jax.pmap,也支持一种通用的提前(ahead of time)显式降低(lowering)和编译的方式。本模块定义了表示此过程阶段的类型。
更多信息,请参阅 AOT 教程。
类#
- class jax.stages.Wrapped(*args, **kwargs)[source]#
一个准备好进行追踪(tracing)、降低(lowering)和编译(compilation)的函数。
此协议反映了
jax.jit等函数的输出。调用它将执行 JIT(即时)降低、编译和执行。它也可以在编译前显式降低,并将结果在执行前编译。
- class jax.stages.Traced(lfg, params, in_tree, out_tree, num_consts)[source]#
一个针对参数类型和值特化的函数追踪形式。
已追踪的计算已准备好进行降低。此类携带了追踪表示以及稍后降低、编译和执行它所需的剩余信息。
- class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False, in_types=None, out_types=None)[source]#
一个针对参数类型和值特化的函数降低表示。
降低(lowering)是一个准备好进行编译的计算。此类携带了降低表示以及稍后编译和执行它所需的剩余信息。它还为查询 JAX 各个降低路径(
jit(),pmap()等)的已降低计算的属性提供了通用 API。- 参数:
lowering (Lowering)
args_info (Any)
out_tree (tree_util.PyTreeDef)
no_kwargs (bool)
- as_text(dialect=None, *, debug_info=False)[source]#
此降低的可读文本表示。
用于可视化和调试目的。这不一定是一个有效的或可靠的序列化。如果您想要可靠且可移植的序列化,请使用 jax.export。
- compiler_ir(dialect=None)[source]#
此降低的任意对象表示。
用于调试目的。这不是一个有效或可靠的序列化。输出在不同调用之间没有一致性保证。如果您想要可靠且可移植的序列化,请使用 jax.export。
如果不可用(例如,取决于后端、编译器或运行时),则返回
None。- 参数:
dialect (str | None) – 可选字符串,指定降低方言(例如,“stablehlo”或“hlo”)。
- 返回类型:
Any | None
- class jax.stages.Compiled(executable, const_args, args_info, out_tree, no_kwargs=False, in_types=None, out_types=None)[source]#
一个针对类型/值特化的函数编译表示。
已编译的计算与一个可执行文件以及执行它所需的剩余信息相关联。它还为查询 JAX 各种编译路径和后端已编译计算的属性提供了一个通用 API。
- 参数:
const_args (list[ArrayLike])
args_info (Any)
out_tree (tree_util.PyTreeDef)
- as_text()[source]#
此可执行文件的可读文本表示。
用于可视化和调试目的。这不是一个有效或可靠的序列化。
如果不可用(例如,取决于后端、编译器或运行时),则返回
None。- 返回类型:
str | None
- cost_analysis()[source]#
执行成本估算的摘要。
用于可视化和调试目的。此方法输出的对象是一个简单的、易于打印或序列化的数据结构(例如,嵌套的字典、列表和具有数值叶子的元组)。然而,其结构可能是任意的:它在 JAX 和 jaxlib 的版本之间,甚至在调用之间都可能不一致。
如果不可用(例如,取决于后端、编译器或运行时),则返回
None。- 返回类型:
Any | None