jax.export.Exported#

class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)[source]#

已降低到 StableHLO 的 JAX 函数。

参数:
  • fun_name (str)

  • in_tree (tree_util.PyTreeDef)

  • in_avals (tuple[core.ShapedArray, ...])

  • out_tree (tree_util.PyTreeDef)

  • out_avals (tuple[core.ShapedArray, ...])

  • in_shardings_hlo (tuple[HloSharding | None, ...])

  • out_shardings_hlo (tuple[HloSharding | None, ...])

  • nr_devices (int)

  • platforms (tuple[str, ...])

  • ordered_effects (tuple[effects.Effect, ...])

  • unordered_effects (tuple[effects.Effect, ...])

  • disabled_safety_checks (Sequence[DisabledSafetyCheck])

  • mlir_module_serialized (bytes)

  • calling_convention_version (int)

  • module_kept_var_idx (tuple[int, ...])

  • uses_global_constants (bool)

  • _get_vjp (Callable[[Exported], Exported] | None)

fun_name#

导出函数的名称,用于错误消息。

类型:

str

in_tree#

一个 PyTreeDef,描述了已降低的 JAX 函数的元组 (args, kwargs)。实际降低不依赖于 in_tree,但可用于使用相同的参数结构调用导出函数。

类型:

tree_util.PyTreeDef

in_avals#

输入抽象值的扁平元组。形状中可能包含维度表达式。

类型:

tuple[core.ShapedArray, …]

out_tree#

描述已降低的 JAX 函数结果的 PyTreeDef。

类型:

tree_util.PyTreeDef

out_avals#

输出抽象值的扁平元组。形状中可能包含维度表达式,维度变量位于 in_avals 中。

类型:

tuple[core.ShapedArray, …]

in_shardings_hlo#

扁平化的输入分片,一个与 in_avals 长度相同的序列。 None 表示未指定分片。请注意,这些不包括网格或网格中使用的实际设备。有关将这些转换为可用于 JAX API 的分片规范的方法,请参见 in_shardings_jax

类型:

tuple[HloSharding | None, …]

out_shardings_hlo#

扁平化的输出分片,一个与 out_avals 长度相同的序列。 None 表示未指定分片。请注意,这些不包括网格或网格中使用的实际设备。有关将这些转换为可用于 JAX API 的分片规范的方法,请参见 out_shardings_jax

类型:

tuple[HloSharding | None, …]

nr_devices#

模块已降低到的设备数量。

类型:

int

platforms#

包含要导出函数的平台的元组。JAX 中的平台集是开放式的;用户可以添加平台。JAX 内置平台是:‘tpu’,‘cpu’,‘cuda’,‘rocm’。请参见 https://jax.net.cn/en/latest/export/export.html#cross-platform-and-multi-platform-export.

类型:

tuple[str, …]

ordered_effects#

序列化模块中存在的排序效果。此效果从序列化版本 9 开始存在。有关在存在排序效果的情况下调用约定的信息,请参见 https://jax.net.cn/en/latest/export/export.html#module-calling-convention.

类型:

tuple[effects.Effect, …]

unordered_effects#

序列化模块中存在的无序效果。此效果从序列化版本 9 开始存在。

类型:

tuple[effects.Effect, …]

mlir_module_serialized#

序列化后的已降低 VHLO 模块。

类型:

bytes

calling_convention_version#

导出模块的调用约定的版本号。有关更多版本信息,请参见 https://jax.net.cn/en/latest/export/export.html#calling-convention-versions.

类型:

int

module_kept_var_idx#

必须传递给模块的 in_avals 中的参数的排序索引。其他参数已被删除,因为它们未使用。

类型:

tuple[int, …]

uses_global_constants#

mlir_module_serialized 是否使用形状多态性或多平台导出。这可能是因为 in_avals 包含维度变量,或者由于对包含维度变量或平台索引参数的 Exported 模块的内部调用。此类模块需要在 XLA 编译之前进行形状细化。

类型:

bool

disabled_safety_checks#

在导出时已禁用的安全检查描述符列表。请参见 DisabledSafetyCheck 的文档字符串。

类型:

Sequence[DisabledSafetyCheck]

_get_vjp#

一个可选函数,它获取当前的导出函数并返回导出的 VJP 函数。VJP 函数接受一个扁平的列表参数,从基本参数开始,然后是每个基本输出的切线参数。它返回一个元组,其中包含与扁平化基本输入相对应的切线。

类型:

Callable[[Exported], Exported] | None

请参见 [对 mlir_module 的调用约定的描述] (https://jax.net.cn/en/latest/export/export.html#module-calling-convention)。

__init__(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)#
参数:
  • fun_name (str)

  • in_tree (tree_util.PyTreeDef)

  • in_avals (tuple[core.ShapedArray, ...])

  • out_tree (tree_util.PyTreeDef)

  • out_avals (tuple[core.ShapedArray, ...])

  • in_shardings_hlo (tuple[HloSharding | None, ...])

  • out_shardings_hlo (tuple[HloSharding | None, ...])

  • nr_devices (int)

  • platforms (tuple[str, ...])

  • ordered_effects (tuple[effects.Effect, ...])

  • unordered_effects (tuple[effects.Effect, ...])

  • disabled_safety_checks (Sequence[DisabledSafetyCheck])

  • mlir_module_serialized (bytes)

  • calling_convention_version (int)

  • module_kept_var_idx (tuple[int, ...])

  • uses_global_constants (bool)

  • _get_vjp (Callable[[Exported], Exported] | None)

返回类型:

None

方法

__init__(fun_name, in_tree, in_avals, ...)

call(*args, **kwargs)

has_vjp()

返回此 Exported 是否支持 VJP。

in_shardings_jax(mesh)

创建与 self.in_shardings_hlo 相对应的 Shardings。

mlir_module()

out_shardings_jax(mesh)

创建与 self.out_shardings_hlo 相对应的 Shardings。

serialize([vjp_order])

序列化 Exported。

vjp()

获取导出的 VJP。

属性

in_shardings

lowering_platforms

已弃用。

mlir_module_serialization_version

已弃用。

out_shardings

uses_shape_polymorphism

已弃用。

fun_name

in_tree

in_avals

out_tree

out_avals

in_shardings_hlo

out_shardings_hlo

nr_devices

platforms

ordered_effects

unordered_effects

disabled_safety_checks

mlir_module_serialized

calling_convention_version

module_kept_var_idx

uses_global_constants