Omnistaging#

mattjj@ 2020 年 9 月 25 日

这更像是一个升级指南,而不是一个设计文档。

目录#

tl;dr#

发生了什么?#

对 JAX 的跟踪基础设施的更改,称为“omnistaging”(jax-ml/jax#3370)在 jax==0.2.0 中已开启。此更改改进了内存性能、跟踪执行时间,并简化了 jax 内部机制,但可能导致一些现有代码中断。中断通常是错误代码的结果,因此从长远来看,最好修复错误,但也可以禁用 omnistaging 作为临时解决方法。我们很乐意帮助您修复!

我如何知道 omnistaging 是否破坏了我的代码?#

判断 omnistaging 是否是罪魁祸首的最简单方法是禁用 omnistaging,看看问题是否消失。请参阅下面的当 omnistaging 打开时,可能出现哪些问题?部分。

我现在如何禁用 omnistaging?#

注意:这适用于 JAX 版本 0.2.0 到 0.2.11;在 JAX 版本 0.2.12 及更高版本中无法禁用 omnistaging

暂时可以通过以下方式禁用 omnistaging

  1. 将 shell 环境变量 JAX_OMNISTAGING 设置为 falsey 值;

  2. 如果您的代码使用 absl 解析标志,则将布尔标志 jax_omnistaging 设置为 falsey 值;

  3. 在您的主文件的顶部附近使用此语句

jax.config.disable_omnistaging()

我如何修复由 omnistaging 暴露的错误?#

到目前为止,omnistaging 最常见的问题是使用 jax.numpy 计算形状值或其他跟踪时常量。有关快速示例,请参见下面的代码块,有关完整详细信息以及其他问题,请参见当 omnistaging 打开时,可能出现哪些问题?部分。

不要这样做

@jit
def f(x):
  input_size = jnp.prod(x.shape)
  if input_size > 100:
    ...

这样做

import numpy as np

@jit
def f(x):
  input_size = np.prod(x.shape)
  if input_size > 100:
    ...

与其将 jax.numpy 视为 numpy 的直接替代品,不如现在更好地考虑仅在您想在加速器(如 GPU)上执行计算时才使用 jax.numpy 操作。

什么是“omnistaging”,它为什么有用?#

Omnistaging 是 JAX 核心升级的名称,旨在将更多计算从逐操作 Python 阶段性输出到 XLA,并避免 jitpmap 和控制流原语中的任何“跟踪时常量折叠”。因此,omnistaging 通过减少跟踪期间的碎片和为 XLA 生成更少的编译时大常量,提高了 JAX 的内存性能(有时会显着提高)。它还可以通过消除跟踪时的逐操作执行来提高跟踪性能。此外,omnistaging 简化了 JAX 核心内部机制,修复了许多未解决的错误,并为即将到来的重要功能奠定了基础。

名称“omnistaging”意味着阶段性输出所有可能的内容。

玩具示例#

jitpmap 这样的 JAX 转换将计算阶段性输出到 XLA。也就是说,我们将它们应用于包含多个原始操作的函数,以便操作不是从 Python 逐个执行,而是全部作为一个端到端优化的 XLA 计算的一部分。

但是,究竟哪些操作会被阶段性输出?在 omnistaging 之前,JAX 仅基于数据依赖性阶段性输出计算。这是一个示例函数,以及它在 omnistaging 更改之前阶段性输出的 XLA HLO 程序

from jax import jit
import jax.numpy as jnp

@jit
def f(x):
  y = jnp.add(1, 1)
  return x * y

f(3)
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}

请注意,add 操作未阶段性输出。相反,我们只看到一个乘法。

这是从该函数 omnistaging 更改之后生成的 HLO

ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}

稍微不那么玩具的例子#

这是一个不太玩具的例子,当我们想要创建布尔掩码时,在实践中可能会出现这种情况

import jax.numpy as jnp
from jax import lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)

omnistaging 之前

ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

select 操作已阶段性输出,但用于构造常量 mask 的操作未阶段性输出。用于构造 mask 的操作不是阶段性输出,而是在 Python 跟踪时逐操作执行,XLA 只看到一个编译时常量 constant.1,表示 mask 的值。这很不幸,因为如果我们阶段性输出了用于构造 mask 的操作,XLA 可以将它们融合到 select 中,并完全避免实现结果。因此,我们最终浪费了内存,使用了可能很大的常量,浪费了时间来调度多个未融合的逐操作 XLA 计算,甚至可能导致内存碎片。

(对应于为 jnp.zeros_like(x) 构造零数组的 broadcast 已阶段性输出,因为 JAX 对来自 jax-ml/jax#1668 的非常简单的表达式很懒惰。在 omnistaging 之后,我们可以删除该惰性子语言并简化 JAX 内部机制。)

创建 mask 未阶段性输出的原因是,在 omnistaging 之前,jit 基于数据依赖性运行。也就是说,jit 仅阶段性输出函数中那些对参数具有数据依赖性的操作。控制流原语和 pmap 的行为类似。在 select_tril 的情况下,构造常量 mask 的操作对参数 x 没有数据依赖性,因此它们未阶段性输出;只有 lax.select 调用具有数据依赖性。

使用 omnistaging,jit 转换函数的动态上下文中所有 jax.numpy 调用都阶段性输出到 XLA。也就是说,在 omnistaging 之后,XLA 为 select_tril 看到的计算是

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}

当 omnistaging 打开时,可能出现哪些问题?#

由于在 jitpmap 的动态上下文中,将所有 jax.numpy 操作从 Python 阶段性输出到 XLA,因此以前可以正常工作的一些代码可能会开始引发严重的错误。如下所述,这些行为在 omnistaging 之前已经存在错误,但 omnistaging 使它们变成了硬错误。

使用 jax.numpy 进行形状计算#

示例#

from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))

错误消息#

[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.net.cn/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

解释#

使用 omnistaging,我们不能像上面使用 jnp.prod 那样使用 jax.numpy 进行形状计算,因为在 jit 函数的动态上下文中,这些操作将从 Python 中阶段性输出为在执行时计算的值,但我们需要它们成为编译时(因此也是跟踪时)常量。

在 omnistaging 之前,此代码不会引发错误,但这是一个常见的性能错误:jnp.prod 计算将在跟踪时在设备上执行,这意味着额外的编译、传输、同步、分配以及可能的内存碎片。

解决方案#

解决方案很简单,就是使用原始的 numpy 进行这些形状计算。我们不仅避免了错误,而且还将计算保留在主机上(并降低了开销)。

这个问题在代码中非常常见,以至于我们尝试使错误消息特别好。除了堆栈跟踪显示抽象跟踪器值导致问题的位置(omni.py:10 中的 jnp.reshape 行)之外,我们还解释了为什么此值首先成为跟踪器,方法是指向上游原始操作,该操作导致它成为抽象跟踪器(来自 omni.py:9 中 jnp.prodreduce_prod)以及跟踪器所属的 jit 装饰函数(omni.py:6 中的 ex1)。

副作用#

示例#

from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
  global key
  key, subkey = random.split(key)
  return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!

最后一次调用具有重复的随机性,但没有硬错误,因为我们没有重新执行 Python。但是,如果我们查看 key,我们会看到一个转义的跟踪器当 omnistaging 开启时

print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>

在 omnistaging 之前,random.split 调用不会被阶段性输出,因此我们不会得到转义的跟踪器。代码仍然存在错误,因为 jitted 函数不会重现原始函数的语义(因为重复使用相同的 PRNG 密钥),最终是由于副作用。

当 omnistaging 开启时,如果我们再次触摸 key,我们将收到一个转义的跟踪器错误

random.normal(key, ())

错误消息#

[... full stack trace …]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).

解释#

我们发现的第二大类 omnistaging 问题与副作用代码有关。此代码已经通过转换有效函数而使 JAX 保修失效,但由于预 omnistaging“跟踪时常量折叠”行为,某些副作用函数仍然可以正确运行。Omnistaging 捕获了更多此类错误。

解决方案#

解决方案是识别依赖副作用的 JAX 转换函数,并重写它们以使其无效。

基于 XLA 优化的小数值差异#

因为使用 omnistaging,更多的计算被阶段性输出到 XLA,而不是一些在跟踪时执行的计算,这可能会导致浮点运算的重新排序。因此,我们已经看到数值行为发生变化,导致公差过紧的测试在 omnistaging 打开时失败。

依赖于已更改的 JAX 内部 API#

Omnistaging 涉及对 JAX 核心代码的一些重大修订,包括删除或更改内部函数。任何依赖于此类内部 JAX API 的代码都可能在 omnistaging 打开时中断,无论是构建错误(来自 pytype)还是运行时错误。

触发 XLA 编译时错误#

由于 omnistaging 涉及将更多代码阶段性输出到 XLA,因此我们已经看到它在某些后端触发了预先存在的 XLA 编译时错误。处理这些问题的最佳方法是报告它们,以便我们可以与 XLA 团队合作进行修复。