Omnistaging#

mattjj@ 2020 年 9 月 25 日

这更像是一份升级指南,而非设计文档。

目录#

太长不看#

发生了什么?#

对 JAX 跟踪基础设施的一项名为“Omnistaging”的更改(jax-ml/jax#3370)已在 jax==0.2.0 版本中启用。这项更改提高了内存性能、跟踪执行时间,并简化了 JAX 内部结构,但也可能导致一些现有代码中断。中断通常是由于代码存在 bug 导致的,因此从长远来看,最好修复这些 bug,但 Omnistaging 也可以作为临时解决方案被禁用。我们乐意协助您进行修复!

我如何知道 Omnistaging 是否破坏了我的代码?#

判断 Omnistaging 是否是罪魁祸首最简单的方法是禁用 Omnistaging,看看问题是否消失。请参阅下面的开启 Omnistaging 后可能出现哪些问题?部分。

我现在如何禁用 Omnistaging?#

注意:这适用于 JAX 0.2.0 至 0.2.11 版本;Omnistaging 在 JAX 0.2.12 及更高版本中无法禁用

暂时可以通过以下方式禁用 Omnistaging:

  1. 将 shell 环境变量 JAX_OMNISTAGING 设置为某个表示“假”的值;

  2. 如果您的代码使用 absl 解析标志,则将布尔标志 jax_omnistaging 设置为某个表示“假”的值;

  3. 在主文件顶部附近使用此语句

jax.config.disable_omnistaging()

如何修复 Omnistaging 暴露的 bug?#

目前,Omnistaging 最常见的问题是使用 jax.numpy 计算形状值或其他跟踪时常量。有关快速示例,请参阅下面的代码块;有关完整详情及其他问题,请参阅开启 Omnistaging 后可能出现哪些问题?部分。

不要这样做

@jit
def f(x):
  input_size = jnp.prod(x.shape)
  if input_size > 100:
    ...

这样做

import numpy as np

@jit
def f(x):
  input_size = np.prod(x.shape)
  if input_size > 100:
    ...

不再将 jax.numpy 视为 numpy 的直接替代品,现在最好只在您希望在加速器(例如 GPU)上执行计算时才考虑使用 jax.numpy 操作。

什么是“Omnistaging”以及它为何有用?#

Omnistaging 是 JAX 核心升级的名称,旨在将更多的逐操作 Python 计算暂存到 XLA,并避免在 jit`、pmap` 和控制流原语中进行“跟踪时常量折叠”。因此,Omnistaging 提高了 JAX 的内存性能(有时是显著提升),它通过减少跟踪期间的碎片化和为 XLA 生成更少的大型编译时常量来实现这一点。它还可以通过消除跟踪时的逐操作执行来提高跟踪性能。此外,Omnistaging 简化了 JAX 核心内部结构,修复了许多未解决的 bug,并为即将推出的重要功能奠定了基础。

“Omnistaging”这个名字意味着暂存所有可能的内容。

简单示例#

诸如 jitpmap 等 JAX 变换会将计算暂存到 XLA。也就是说,我们将它们应用于包含多个原语操作的函数,这样这些操作就不会从 Python 中逐个执行,而是作为单个端到端优化的 XLA 计算的一部分。

但究竟哪些操作会被暂存出去呢?在 Omnistaging 之前,JAX 只基于数据依赖暂存计算。这是一个示例函数,后面是它在 Omnistaging 更改之前暂存的 XLA HLO 程序

from jax import jit
import jax.numpy as jnp

@jit
def f(x):
  y = jnp.add(1, 1)
  return x * y

f(3)
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}

请注意,add 操作没有被暂存。相反,我们只看到一个乘法操作。

这是 Omnistaging 更改之后从此函数生成的 HLO

ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}

稍复杂的示例#

这是一个不那么简单的示例,它在实际创建布尔掩码时可能会出现

import jax.numpy as jnp
from jax import lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)

Omnistaging 之前

ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}

select 操作被暂存,但构建常量 mask 的操作没有被暂存。构建 mask 的操作没有被暂存,而是在 Python 跟踪时逐个执行,XLA 只看到一个表示 mask 值的编译时常量 constant.1`。这很不幸,因为如果我们暂存了构建 mask 的操作,XLA 本可以将它们融合到 select 中,从而完全避免具体化结果。结果是我们浪费了内存来存储一个可能很大的常量,浪费了时间来调度多个未融合的逐操作 XLA 计算,甚至可能导致内存碎片化。

(对应于为 jnp.zeros_like(x) 构建零数组的 broadcast 被暂存,因为 JAX 对来自 jax-ml/jax#1668 的非常简单的表达式采取了惰性处理。在 Omnistaging 之后,我们可以删除这种惰性子语言并简化 JAX 内部结构。)

之所以没有暂存 mask 的创建,是因为在 Omnistaging 之前,jit 是基于数据依赖进行操作的。也就是说,jit 只暂存函数中那些对参数存在数据依赖的操作。控制流原语和 pmap 的行为类似。在 select_tril 的情况下,构建常量 mask 的操作与参数 x 没有数据依赖,因此它们没有被暂存;只有 lax.select 调用具有数据依赖。

启用 Omnistaging 后,jit 变换函数动态上下文中的所有 jax.numpy 调用都将被暂存到 XLA。也就是说,启用 Omnistaging 后,XLA 为 select_tril 看到的计算是

ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}

开启 Omnistaging 后可能出现哪些问题?#

由于在 jitpmap 的动态上下文中,所有 jax.numpy 操作都从 Python 暂存到 XLA,一些以前能正常工作的代码可能会开始抛出明显的错误。正如下面解释的,这些行为在 Omnistaging 之前就已经存在 bug,但 Omnistaging 使它们变成了硬性错误。

使用 jax.numpy 进行形状计算#

示例#

from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))

错误消息#

[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.net.cn/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

解释#

启用 Omnistaging 后,我们不能像上面使用 jnp.prod 那样使用 jax.numpy 进行形状计算,因为在 jit 函数的动态上下文中,这些操作将作为在执行时计算的值从 Python 中暂存出去,而我们却需要它们是编译时(以及跟踪时)常量。

在 Omnistaging 之前,这段代码不会引发错误,但它是一个常见的性能 bug:jnp.prod 计算将在跟踪时在设备上执行,这意味着额外的编译、数据传输、同步、内存分配,并可能导致内存碎片化。

解决方案#

解决方案很简单,只需使用原始的 numpy 进行这类形状计算。这样不仅可以避免错误,还可以将计算保留在主机上(且开销更低)。

这个问题在代码中很常见,因此我们尝试让错误消息尽可能清晰。除了显示抽象跟踪器值导致问题的堆栈跟踪(完整堆栈跟踪中 omni.py:10 的 jnp.reshape 行)之外,我们还通过指出导致它成为抽象跟踪器的上游原始操作(omni.py:9 的 jnp.prod 中的 reduce_prod)以及跟踪器所属的 jit 装饰函数(omni.py:6 的 ex1)来解释为什么这个值首先成为了跟踪器。

副作用#

示例#

from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
  global key
  key, subkey = random.split(key)
  return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!

最后的那个调用具有重复的随机性但没有硬性错误,因为我们没有重新执行 Python。但是,如果我们查看 key`,我们会在开启 Omnistaging 时看到一个逃逸的跟踪器

print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>

在 Omnistaging 之前,random.split 调用不会被暂存,因此我们不会得到逃逸的跟踪器。代码仍然存在 bug,因为 jit 化的函数无法复现原始函数的语义(由于重复使用了相同的 PRNG 密钥),最终是由于副作用造成的。

开启 Omnistaging 后,如果再次触及 key`,我们将得到一个逃逸跟踪器错误

random.normal(key, ())

错误消息#

[... full stack trace …]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).

解释#

我们发现的第二大类 Omnistaging 问题与带有副作用的代码有关。这种代码通过转换有副作用的函数,已经违反了 JAX 的约定,但由于 Omnistaging 之前的“跟踪时常量折叠”行为,一些有副作用的函数仍然可以正常运行。Omnistaging 捕捉了更多这类错误。

解决方案#

解决方案是识别那些依赖于副作用的 JAX 转换函数,并重写它们使其不再具有副作用。

基于 XLA 优化的小幅数值差异#

因为启用 Omnistaging 后,更多的计算被暂存到 XLA,而不是在跟踪时执行,这可能会导致浮点运算的重新排序。因此,我们观察到数值行为发生了变化,导致在开启 Omnistaging 时,那些具有过于严格容差的测试会失败。

依赖于已更改的 JAX 内部 API#

Omnistaging 涉及对 JAX 核心代码进行了一些重大修订,包括删除或更改内部函数。任何依赖此类 JAX 内部 API 的代码在开启 Omnistaging 时都可能中断,无论是出现构建错误(来自 pytype)还是运行时错误。

触发 XLA 编译时错误#

由于 Omnistaging 涉及将更多代码暂存到 XLA,我们发现它会在某些后端触发预先存在的 XLA 编译时 bug。解决这些问题的最佳方法是报告它们,以便我们与 XLA 团队合作进行修复。