使用 set_xla_metadata 附加 XLA 元数据#

摘要: set_xla_metadata 允许您将元数据附加到 JAX 代码中的操作。此元数据将作为 frontend_attributes 传递给 XLA 编译器,可用于启用编译器级别的调试工具,例如 XLA-TPU 调试器。

您可以通过三种方式使用它:

  1. 通过包装其输出值来标记单个操作

  2. 使用上下文管理器标记操作块

  3. 使用装饰器标记函数中的所有操作

警告: set_xla_metadata 是一项实验性功能,其 API 可能会发生变化。

什么是 XLA 元数据?#

当 JAX 转换和编译您的代码时,它最终会生成一个 XLA(加速线性代数)计算图。此图中的每个操作都可以关联元数据,特别是 frontend_attributes。此元数据不会更改操作的数值结果,但可用于向编译器或运行时指示特殊行为。

set_xla_metadata 提供了一种直接从 JAX 代码附加此元数据的方法。这是低级调试和性能分析的强大功能。

用法#

标记单个操作#

标记单个操作可以精确控制您希望检查计算的哪些部分。为此,请使用 set_xla_metadata 包装操作的输出(值)。在包装包含多个操作的函数时,仅会标记该函数的最后一个操作。

import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging an individual operation
def value_tagging(x):
  y = jnp.sin(x)
  z = jnp.cos(x)
  return set_xla_metadata(y * z, breakpoint=True)

print(jax.jit(value_tagging).lower(1.0).as_text("hlo"))

结果是

ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1)
  cos.3 = f32[] cosine(x.1)
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"}
}

使用上下文管理器或装饰器标记代码块#

如果您想将相同的元数据应用于更大的代码段,可以使用 set_xla_metadata 作为上下文管理器。`with` 块内的所有 JAX 操作都将附加指定的元数据。

import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging a block of code
def context_tagging(x):
  with set_xla_metadata(_xla_log=True):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return y * z

print(jax.jit(context_tagging).lower(1.0).as_text("hlo"))

结果是

ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1), frontend_attributes={_xla_log="true"}
  cos.3 = f32[] cosine(sin.2), frontend_attributes={_xla_log="true"}
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={_xla_log="true"}
}

如果您想标记函数中的所有操作,也可以使用 set_xla_metadata 作为装饰器。

import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging with a decorator
@set_xla_metadata(_xla_log=True)
@jax.jit
def decorator_tagging(x):
  y = jnp.sin(x)
  z = jnp.cos(y)
  return y * z

print(decorator_tagging.lower(1.0).as_text("hlo"))

这将产生与上面相同的 HLO。

与 JAX 转换的交互#

set_xla_metadata 根据用例使用 XlaMetadataContextManager 或 JAX primitive,并且与 `jit`、`vmap` 和 `grad` 等 JAX 转换兼容。

  • `vmap`:当您 `vmap` 包含 `set_xla_metadata` 的函数时,元数据将应用于所有相关的批量操作。

  • grad:

    1. 当使用 **上下文管理器** `with set_xla_metadata(...):` 标记操作块时,元数据将应用于其中操作的前向和后向传递。

    2. 使用 `set_xla_metadata()` 标记 **单个操作** 目前仅适用于函数的正向传播。要标记反向传播(即梯度计算)生成的单个操作,可以使用简单的 `custom_vjp`。

      import jax
      import jax.numpy as jnp
      from jax.experimental.xla_metadata import set_xla_metadata
      
      def fn(x):
          y = jnp.sin(x)
          z = jnp.cos(x)
          return y * z
      
      metadata = {"example": "grad_tagging"}
      
      # --- Define Custom VJP to tag gradients ---
      @jax.custom_vjp
      def wrapped_fn(x):
          return fn(x)
      
      def fwd(*args):
          primal_out, vjp_fn = jax.vjp(fn, *args)
          return primal_out, vjp_fn
      
      def bwd(vjp_fn, cts_in):
          cts_out = vjp_fn(cts_in)
          cts_out = set_xla_metadata(cts_out, **metadata)
          return cts_out
      
      wrapped_fn.defvjp(fwd, bwd)
      # ------
      
      print(jax.jit(jax.grad(wrapped_fn)).lower(jnp.array(3.0)).as_text("hlo"))
      

      结果是

      ENTRY main.10 {
        x.1 = f32[] parameter(0)
        sin.2 = f32[] sine(x.1)
        neg.6 = f32[] negate(sin.2)
        sin.5 = f32[] sine(x.1)
        mul.7 = f32[] multiply(neg.6, sin.5)
        cos.4 = f32[] cosine(x.1)
        cos.3 = f32[] cosine(x.1)
        mul.8 = f32[] multiply(cos.4, cos.3)
        ROOT add_any.9 = f32[] add(mul.7, mul.8), frontend_attributes={example="grad_tagging"}
      }
      

set_xla_metadata 的优点和局限性#

优点#

  • 灵活控制:允许您针对单个操作或操作块。

  • 非侵入性:不改变程序的数值输出或融合行为。

  • 支持强大的工具:解锁了在编译器级别进行复杂调试和分析的潜力。

局限性#

  • 属性可能丢失:虽然旨在使 XLA 元数据在整个转换和 HLO 优化过程中得以维护,但在某些边缘情况下,元数据可能会丢失。

  • 仅限正向传播:目前,在标记反向传播中的 **单个操作** 时,元数据不会自动传播到梯度。在这种情况下,必须使用 `custom_vjp` 来标记梯度。请参阅上文示例。

  • 易变:`set_xla_metadata` 是一项实验性功能,其 API 可能会发生变化。