jax.export.Exported#
- class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)[source]#
已降低到 StableHLO 的 JAX 函数。
- 参数:
fun_name (str)
in_tree (tree_util.PyTreeDef)
in_avals (tuple[core.ShapedArray, ...])
out_tree (tree_util.PyTreeDef)
out_avals (tuple[core.ShapedArray, ...])
in_shardings_hlo (tuple[HloSharding | None, ...])
out_shardings_hlo (tuple[HloSharding | None, ...])
nr_devices (int)
ordered_effects (tuple[effects.Effect, ...])
unordered_effects (tuple[effects.Effect, ...])
disabled_safety_checks (Sequence[DisabledSafetyCheck])
mlir_module_serialized (bytes)
calling_convention_version (int)
uses_global_constants (bool)
- in_tree#
一个 PyTreeDef,描述了已降低的 JAX 函数的元组 (args, kwargs)。实际降低不依赖于 in_tree,但可用于使用相同的参数结构调用导出函数。
- 类型:
tree_util.PyTreeDef
- out_tree#
描述已降低的 JAX 函数结果的 PyTreeDef。
- 类型:
tree_util.PyTreeDef
- in_shardings_hlo#
扁平化的输入分片,一个与 in_avals 长度相同的序列。 None 表示未指定分片。请注意,这些不包括网格或网格中使用的实际设备。有关将这些转换为可用于 JAX API 的分片规范的方法,请参见 in_shardings_jax。
- 类型:
tuple[HloSharding | None, …]
- out_shardings_hlo#
扁平化的输出分片,一个与 out_avals 长度相同的序列。 None 表示未指定分片。请注意,这些不包括网格或网格中使用的实际设备。有关将这些转换为可用于 JAX API 的分片规范的方法,请参见 out_shardings_jax。
- 类型:
tuple[HloSharding | None, …]
- platforms#
包含要导出函数的平台的元组。JAX 中的平台集是开放式的;用户可以添加平台。JAX 内置平台是:‘tpu’,‘cpu’,‘cuda’,‘rocm’。请参见 https://jax.net.cn/en/latest/export/export.html#cross-platform-and-multi-platform-export.
- ordered_effects#
序列化模块中存在的排序效果。此效果从序列化版本 9 开始存在。有关在存在排序效果的情况下调用约定的信息,请参见 https://jax.net.cn/en/latest/export/export.html#module-calling-convention.
- 类型:
tuple[effects.Effect, …]
- calling_convention_version#
导出模块的调用约定的版本号。有关更多版本信息,请参见 https://jax.net.cn/en/latest/export/export.html#calling-convention-versions.
- 类型:
- uses_global_constants#
mlir_module_serialized 是否使用形状多态性或多平台导出。这可能是因为 in_avals 包含维度变量,或者由于对包含维度变量或平台索引参数的 Exported 模块的内部调用。此类模块需要在 XLA 编译之前进行形状细化。
- 类型:
- disabled_safety_checks#
在导出时已禁用的安全检查描述符列表。请参见 DisabledSafetyCheck 的文档字符串。
- 类型:
Sequence[DisabledSafetyCheck]
- _get_vjp#
一个可选函数,它获取当前的导出函数并返回导出的 VJP 函数。VJP 函数接受一个扁平的列表参数,从基本参数开始,然后是每个基本输出的切线参数。它返回一个元组,其中包含与扁平化基本输入相对应的切线。
请参见 [对 mlir_module 的调用约定的描述] (https://jax.net.cn/en/latest/export/export.html#module-calling-convention)。
- __init__(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)#
- 参数:
fun_name (str)
in_tree (tree_util.PyTreeDef)
in_avals (tuple[core.ShapedArray, ...])
out_tree (tree_util.PyTreeDef)
out_avals (tuple[core.ShapedArray, ...])
in_shardings_hlo (tuple[HloSharding | None, ...])
out_shardings_hlo (tuple[HloSharding | None, ...])
nr_devices (int)
ordered_effects (tuple[effects.Effect, ...])
unordered_effects (tuple[effects.Effect, ...])
disabled_safety_checks (Sequence[DisabledSafetyCheck])
mlir_module_serialized (bytes)
calling_convention_version (int)
uses_global_constants (bool)
- 返回类型:
None
方法
__init__
(fun_name, in_tree, in_avals, ...)call
(*args, **kwargs)has_vjp
()返回此 Exported 是否支持 VJP。
in_shardings_jax
(mesh)创建与 self.in_shardings_hlo 相对应的 Shardings。
mlir_module
()out_shardings_jax
(mesh)创建与 self.out_shardings_hlo 相对应的 Shardings。
serialize
([vjp_order])序列化 Exported。
vjp
()获取导出的 VJP。
属性
in_shardings
lowering_platforms
已弃用。
mlir_module_serialization_version
已弃用。
out_shardings
uses_shape_polymorphism
已弃用。