jax.stages 模块#

编译执行过程各个阶段的接口。

JAX 转换(如 jax.jitjax.pmap)在执行时进行即时编译,但也支持一种常见的显式降低和提前编译的方法。此模块定义了代表此过程各个阶段的类型。

更多信息,请参阅 AOT 演练

#

class jax.stages.Wrapped(*args, **kwargs)[source]#

一个准备好进行跟踪、降低和编译的函数。

此协议反映了 jax.jit 等函数的输出。调用它会进行 JIT(即时)降低、编译和执行。它也可以在编译前显式降低,结果可以在执行前编译。

__call__(*args, **kwargs)[source]#

执行包装函数,根据需要进行降低和编译。

lower(*args, **kwargs)[source]#

为给定参数显式降低此函数。

这是 self.trace(*args, **kwargs).lower() 的快捷方式。

降低后的函数会脱离 Python 并转换为编译器的输入语言,这可能以后端相关的方式进行。它已准备好进行编译但尚未编译。

返回:

表示降低的 Lowered 实例。

返回类型:

Lowered

trace(*args, **kwargs)[source]#

为给定参数显式跟踪此函数。

跟踪后的函数会脱离 Python 并转换为 jaxpr。它已准备好进行降低但尚未降低。

返回:

表示跟踪的 Traced 实例。

返回类型:

Traced

class jax.stages.Traced(jaxpr, args_info, fun_name, out_tree, lower_callable, args_flat=None, arg_names=None, num_consts=0, params_out_shardings=None)[source]#

针对参数类型和值进行专门化的函数跟踪形式。

跟踪后的计算已准备好进行降低。此类包含跟踪表示以及后续降低、编译和执行所需的信息。

参数:
lower(*, lowering_platforms=None, _private_parameters=None)[source]#

降低为编译器输入,返回一个 Lowered 实例。

参数:
  • lowering_platforms (tuple[str, ...] | None)

  • _private_parameters (mlir.LoweringParameters | None)

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)[source]#

针对参数类型和值进行专门化的函数降低。

降低是已准备好编译的计算。此类包含一个降低以及后续编译和执行所需的信息。它还提供了一个通用 API,用于查询 JAX 各种降低路径(如 jit()pmap() 等)中降低后的计算属性。

参数:
  • lowering (Lowering)

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

  • no_kwargs (bool)

as_text(dialect=None, *, debug_info=False)[source]#

此降低的可读文本表示。

用于可视化和调试目的。这不一定是有效或可靠的序列化。如果需要可靠和可移植的序列化,请使用 jax.export

参数:
  • dialect (str | None) – 可选字符串,指定降低方言(例如“stablehlo”或“hlo”)。

  • debug_info (bool) – 是否包含调试信息,例如源位置。

返回类型:

str

compile(compiler_options=None, *, device_assignment=None)[source]#

编译,返回一个对应的 Compiled 实例。

参数:
  • compiler_options (CompilerOptions | None)

  • device_assignment (tuple[xc.Device, ...] | None)

返回类型:

Compiled

compiler_ir(dialect=None)[source]#

此降低的任意对象表示。

用于调试目的。这不是有效或可靠的序列化。输出不能保证在不同调用之间保持一致性。如果需要可靠和可移植的序列化,请使用 jax.export

如果不可用,则返回 None,例如基于后端、编译器或运行时。

参数:

dialect (str | None) – 可选字符串,指定降低方言(例如“stablehlo”或“hlo”)。

返回类型:

Any | None

cost_analysis()[source]#

执行成本估算的摘要。

用于可视化和调试目的。由此输出的对象是一些可以轻松打印或序列化(例如,带有数字叶子的嵌套字典、列表和元组)的简单数据结构。但是,其结构可以是任意的:它在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能不一致。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

Any | None

property in_tree: tree_util.PyTreeDef[source]#

(位置参数,关键字参数)对的树结构。

class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)[source]#

针对类型/值进行专门化的函数的编译表示。

编译后的计算与可执行文件以及执行它所需的其余信息相关联。它还提供了一个通用 API,用于查询 JAX 各种编译路径和后端中编译后计算的属性。

参数:
  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

__call__(*args, **kwargs)[source]#

将自身作为函数调用。

as_text()[source]#

此可执行文件的可读文本表示。

用于可视化和调试目的。这不是有效或可靠的序列化。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

str | None

cost_analysis()[source]#

执行成本估算的摘要。

用于可视化和调试目的。由此输出的对象是一些可以轻松打印或序列化(例如,带有数字叶子的嵌套字典、列表和元组)的简单数据结构。但是,其结构可以是任意的:它在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能不一致。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

Any | None

property in_tree: tree_util.PyTreeDef[source]#

(位置参数,关键字参数)对的树结构。

memory_analysis()[source]#

估计内存需求的摘要。

用于可视化和调试目的。由此输出的对象是一些可以轻松打印或序列化(例如,带有数字叶子的嵌套字典、列表和元组)的简单数据结构。但是,其结构可以是任意的:它在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能不一致。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

Any | None

runtime_executable()[source]#

此可执行文件的任意对象表示。

用于调试目的。这不是有效或可靠的序列化。输出不能保证在不同调用之间保持一致性。

如果不可用,则返回 None,例如基于后端、编译器或运行时。

返回类型:

Any | None