使用 set_xla_metadata 附加 XLA 元数据#
摘要: set_xla_metadata
允许您将元数据附加到 JAX 代码中的操作。此元数据将作为 frontend_attributes
传递给 XLA 编译器,可用于启用编译器级别的调试工具,例如 XLA-TPU 调试器。
您可以通过三种方式使用它:
通过包装其输出值来标记单个操作
使用上下文管理器标记操作块
使用装饰器标记函数中的所有操作
警告: set_xla_metadata
是一项实验性功能,其 API 可能会发生变化。
什么是 XLA 元数据?#
当 JAX 转换和编译您的代码时,它最终会生成一个 XLA(加速线性代数)计算图。此图中的每个操作都可以关联元数据,特别是 frontend_attributes
。此元数据不会更改操作的数值结果,但可用于向编译器或运行时指示特殊行为。
set_xla_metadata
提供了一种直接从 JAX 代码附加此元数据的方法。这是低级调试和性能分析的强大功能。
用法#
标记单个操作#
标记单个操作可以精确控制您希望检查计算的哪些部分。为此,请使用 set_xla_metadata
包装操作的输出(值)。在包装包含多个操作的函数时,仅会标记该函数的最后一个操作。
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata
# Tagging an individual operation
def value_tagging(x):
y = jnp.sin(x)
z = jnp.cos(x)
return set_xla_metadata(y * z, breakpoint=True)
print(jax.jit(value_tagging).lower(1.0).as_text("hlo"))
结果是
ENTRY main.5 {
x.1 = f32[] parameter(0)
sin.2 = f32[] sine(x.1)
cos.3 = f32[] cosine(x.1)
ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"}
}
使用上下文管理器或装饰器标记代码块#
如果您想将相同的元数据应用于更大的代码段,可以使用 set_xla_metadata
作为上下文管理器。`with` 块内的所有 JAX 操作都将附加指定的元数据。
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata
# Tagging a block of code
def context_tagging(x):
with set_xla_metadata(_xla_log=True):
y = jnp.sin(x)
z = jnp.cos(y)
return y * z
print(jax.jit(context_tagging).lower(1.0).as_text("hlo"))
结果是
ENTRY main.5 {
x.1 = f32[] parameter(0)
sin.2 = f32[] sine(x.1), frontend_attributes={_xla_log="true"}
cos.3 = f32[] cosine(sin.2), frontend_attributes={_xla_log="true"}
ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={_xla_log="true"}
}
如果您想标记函数中的所有操作,也可以使用 set_xla_metadata
作为装饰器。
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata
# Tagging with a decorator
@set_xla_metadata(_xla_log=True)
@jax.jit
def decorator_tagging(x):
y = jnp.sin(x)
z = jnp.cos(y)
return y * z
print(decorator_tagging.lower(1.0).as_text("hlo"))
这将产生与上面相同的 HLO。
与 JAX 转换的交互#
set_xla_metadata
根据用例使用 XlaMetadataContextManager
或 JAX primitive
,并且与 `jit`、`vmap` 和 `grad` 等 JAX 转换兼容。
`vmap`:当您 `vmap` 包含 `set_xla_metadata` 的函数时,元数据将应用于所有相关的批量操作。
grad
:当使用 **上下文管理器** `with set_xla_metadata(...):` 标记操作块时,元数据将应用于其中操作的前向和后向传递。
使用 `set_xla_metadata()` 标记 **单个操作** 目前仅适用于函数的正向传播。要标记反向传播(即梯度计算)生成的单个操作,可以使用简单的 `custom_vjp`。
import jax import jax.numpy as jnp from jax.experimental.xla_metadata import set_xla_metadata def fn(x): y = jnp.sin(x) z = jnp.cos(x) return y * z metadata = {"example": "grad_tagging"} # --- Define Custom VJP to tag gradients --- @jax.custom_vjp def wrapped_fn(x): return fn(x) def fwd(*args): primal_out, vjp_fn = jax.vjp(fn, *args) return primal_out, vjp_fn def bwd(vjp_fn, cts_in): cts_out = vjp_fn(cts_in) cts_out = set_xla_metadata(cts_out, **metadata) return cts_out wrapped_fn.defvjp(fwd, bwd) # ------ print(jax.jit(jax.grad(wrapped_fn)).lower(jnp.array(3.0)).as_text("hlo"))
结果是
ENTRY main.10 { x.1 = f32[] parameter(0) sin.2 = f32[] sine(x.1) neg.6 = f32[] negate(sin.2) sin.5 = f32[] sine(x.1) mul.7 = f32[] multiply(neg.6, sin.5) cos.4 = f32[] cosine(x.1) cos.3 = f32[] cosine(x.1) mul.8 = f32[] multiply(cos.4, cos.3) ROOT add_any.9 = f32[] add(mul.7, mul.8), frontend_attributes={example="grad_tagging"} }
set_xla_metadata 的优点和局限性#
优点#
灵活控制:允许您针对单个操作或操作块。
非侵入性:不改变程序的数值输出或融合行为。
支持强大的工具:解锁了在编译器级别进行复杂调试和分析的潜力。
局限性#
属性可能丢失:虽然旨在使 XLA 元数据在整个转换和 HLO 优化过程中得以维护,但在某些边缘情况下,元数据可能会丢失。
仅限正向传播:目前,在标记反向传播中的 **单个操作** 时,元数据不会自动传播到梯度。在这种情况下,必须使用 `custom_vjp` 来标记梯度。请参阅上文示例。
易变:`set_xla_metadata` 是一项实验性功能,其 API 可能会发生变化。