Omnistaging#
mattjj@ 2020 年 9 月 25 日
这更像是一份升级指南,而不是设计文档。
目录#
tl;dr#
发生了什么?#
JAX 的追踪基础设施进行了一项名为“omnistaging”(jax-ml/jax#3370)的更改,该更改在 jax==0.2.0 中启用。此更改提高了内存性能、跟踪执行时间并简化了 JAX 内部结构,但可能会导致一些现有代码中断。中断通常是代码存在错误的结果,因此从长远来看,最好修复这些错误,但也可以暂时禁用 omnistaging 作为一种临时解决方法。我们很乐意帮助您修复这些错误!
我如何知道我的代码是否因 omnistaging 而中断?#
判断是否是 omnistaging 导致的最简单方法是禁用 omnistaging,然后看看问题是否消失。请参阅下面的启用 omnistaging 后可能出现哪些问题? 部分。
我目前如何禁用 omnistaging?#
注意:这适用于 JAX 版本 0.2.0 到 0.2.11;在 JAX 版本 0.2.12 及更高版本中无法禁用 omnistaging
暂时可以通过以下方式禁用 omnistaging:
将 shell 环境变量
JAX_OMNISTAGING
设置为 false 值;如果您的代码使用 absl 解析标志,则将布尔标志
jax_omnistaging
设置为 false 值;在您的主文件顶部附近使用此语句
jax.config.disable_omnistaging()
我如何修复 omnistaging 暴露的错误?#
到目前为止,omnistaging 最常见的问题是使用 jax.numpy
来计算形状值或其他跟踪时常量。有关快速示例,请参阅下面的代码块,有关完整详细信息和其他问题,请参阅启用 omnistaging 后可能出现哪些问题?部分。
不要这样做
@jit
def f(x):
input_size = jnp.prod(x.shape)
if input_size > 100:
...
而要这样做
import numpy as np
@jit
def f(x):
input_size = np.prod(x.shape)
if input_size > 100:
...
现在,不要将 jax.numpy
视为 numpy
的直接替代品,最好将其视为仅在想要在加速器(如 GPU)上执行计算时才使用 jax.numpy
操作。
什么是“omnistaging”,为什么它有用?#
Omnistaging 是 JAX 核心升级的名称,旨在将更多计算从逐个操作的 Python 转移到 XLA,并避免在 jit
、pmap
和控制流原语中进行任何“跟踪时常量折叠”。因此,omnistaging 通过减少跟踪期间的碎片以及为 XLA 生成更少的编译时常量,从而提高了 JAX 的内存性能(有时会显着提高)。它还可以通过消除跟踪时的逐个操作执行来提高跟踪性能。此外,omnistaging 简化了 JAX 核心内部结构,修复了许多未解决的错误,并为即将推出的重要功能奠定了基础。
名称“omnistaging”表示尽可能多地转移所有内容。
玩具示例#
JAX 转换(如 jit
和 pmap
)将计算转移到 XLA。也就是说,我们将它们应用于包含多个原始操作的函数,这样,这些操作不是从 Python 一次执行一个,而是都属于一个端到端优化的 XLA 计算。
但是,究竟哪些操作会被转移?在 omnistaging 之前,JAX 仅基于数据依赖性来转移计算。以下是一个示例函数,以及它在 omnistaging 更改之前转移的 XLA HLO 程序
from jax import jit
import jax.numpy as jnp
@jit
def f(x):
y = jnp.add(1, 1)
return x * y
f(3)
ENTRY jit_f.6 {
constant.2 = pred[] constant(false)
parameter.1 = s32[] parameter(0)
constant.3 = s32[] constant(2)
multiply.4 = s32[] multiply(parameter.1, constant.3)
ROOT tuple.5 = (s32[]) tuple(multiply.4)
}
请注意,add
操作未被转移。相反,我们只看到乘法。
以下是从此函数生成的 HLO,在 omnistaging 更改之后
ENTRY jit_f.8 {
constant.2 = pred[] constant(false)
parameter.1 = s32[] parameter(0)
constant.3 = s32[] constant(1)
constant.4 = s32[] constant(1)
add.5 = s32[] add(constant.3, constant.4)
multiply.6 = s32[] multiply(parameter.1, add.5)
ROOT tuple.7 = (s32[]) tuple(multiply.6)
}
稍微不那么玩具的示例#
以下是一个不太玩具化的示例,当我们要创建布尔掩码时,实际上可能会出现这种情况
import jax.numpy as jnp
from jax import lax
@jit
def select_tril(x):
mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
return lax.select(mask, x, jnp.zeros_like(x)) # lax.select is like jnp.where
x = np.arange(12).reshape((3, 4))
select_tril(x)
在 omnistaging 之前
ENTRY jit_select_tril.8 {
constant.3 = pred[] constant(false)
constant.1 = pred[3,4]{1,0} constant({...})
parameter.2 = s32[3,4]{1,0} parameter(0)
constant.4 = s32[] constant(0)
broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}
select
操作被转移,但是用于构造常量 mask
的操作则没有。构建 mask
的操作不是被转移,而是在 Python 跟踪时逐个操作执行,并且 XLA 只看到一个编译时常量 constant.1
,表示 mask
的值。这很不幸,因为如果我们转移了用于构造 mask
的操作,XLA 可以将它们融合到 select
中,并完全避免具体化结果。因此,我们最终会浪费内存(可能很大),浪费时间分派多个未融合的逐个操作的 XLA 计算,甚至可能会碎片化内存。
(对应于 jnp.zeros_like(x)
的零数组构造的 broadcast
被转移,因为 JAX 对于来自 jax-ml/jax#1668 的非常简单的表达式是延迟的。在 omnistaging 之后,我们可以删除该延迟的子语言并简化 JAX 内部结构。)
创建 mask
没有被转移的原因是,在 omnistaging 之前,jit
基于数据依赖性运行。也就是说,jit
只会转移函数中那些对参数有数据依赖性的操作。控制流原语和 pmap
的行为类似。在 select_tril
的情况下,构造常量 mask
的操作对参数 x 没有数据依赖性,因此它们不会被转移;只有 lax.select
调用具有数据依赖性。
借助 omnistaging,jit
转换函数的动态上下文中所有 jax.numpy
调用都会被转移到 XLA。也就是说,在 omnistaging 之后,XLA 对于 select_tril
看到的计算是
ENTRY jit_select_tril.16 {
constant.4 = pred[] constant(false)
iota.1 = s32[3]{0} iota(), iota_dimension=0
broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
reshape.7 = s32[3]{0} reshape(broadcast.5)
broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
iota.2 = s32[4]{0} iota(), iota_dimension=0
broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
reshape.9 = s32[4]{0} reshape(broadcast.6)
broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
parameter.3 = s32[3,4]{1,0} parameter(0)
constant.12 = s32[] constant(0)
broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}
启用 omnistaging 后可能出现哪些问题?#
由于在 jit
或 pmap
的动态上下文中,将所有 jax.numpy
操作从 Python 转移到 XLA,因此之前正常运行的某些代码可能会开始引发明显的错误。如下所述,这些行为在 omnistaging 之前已经存在错误,但 omnistaging 将它们变成了硬错误。
使用 jax.numpy
进行形状计算#
示例#
from jax import jit
import jax.numpy as jnp
@jit
def ex1(x):
size = jnp.prod(jnp.array(x.shape))
return x.reshape((size,))
ex1(jnp.ones((3, 4)))
错误消息#
[... full traceback ...]
File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The error arose in jax.numpy.reshape.
While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:
operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
from line ex1.py:6 (ex1)
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.net.cn/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
解释#
借助 omnistaging,我们不能像上面使用 jnp.prod
一样使用 jax.numpy
进行形状计算,因为在 jit 函数的动态上下文中,这些操作将从 Python 中转移出去,作为在执行时要计算的值,但我们需要它们成为编译时(因此是跟踪时)常量。
在 omnistaging 之前,此代码不会引发错误,但这是一种常见的性能错误:jnp.prod
计算将在跟踪时在设备上执行,这意味着额外的编译、传输、同步、分配以及潜在的内存碎片。
解决方案#
解决方案很简单,就是使用原始的 numpy
进行此类形状计算。这样不仅可以避免错误,还可以将计算保留在主机上(并降低开销)。
此问题在代码中很常见,因此我们试图使错误消息特别好。除了堆栈跟踪显示抽象跟踪器值导致问题的位置(完整堆栈跟踪中的 jnp.reshape
行,位于 omni.py:10),我们还通过指向导致其成为抽象跟踪器的上游原始操作(来自 omni.py:9 上的 jnp.prod
的 reduce_prod
)以及跟踪器所属的 jit
修饰的函数(位于 omni.py:6 上的 ex1
)来解释为什么此值首先成为跟踪器。
副作用#
示例#
from jax import jit
from jax import random
key = random.PRNGKey(0)
def init():
global key
key, subkey = random.split(key)
return random.normal(subkey, ())
print(init()) # -1.2515389
print(init()) # -0.58665067
init = jit(init)
print(init()) # 0.48648298
print(init()) # 0.48648298 !!
最后一次调用具有重复的随机性,但没有硬错误,因为我们没有重新执行 Python。但是,如果我们查看 key
,我们会看到一个逃逸的跟踪器,当 omnistaging 开启时
print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
在 omnistaging 之前,random.split
调用不会被转移,因此我们不会得到逃逸的跟踪器。该代码仍然存在错误,因为 jitted 函数不会重现原始函数的语义(由于重复使用相同的 PRNG 密钥),最终是由于副作用。
启用 omnistaging 后,如果我们再次接触 key
,我们将收到一个逃逸的跟踪器错误
random.normal(key, ())
错误消息#
[... full stack trace …]
File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).
解释#
我们发现的第二大类 omnistaging 问题与副作用代码有关。此代码已经通过转换有副作用的函数而违反了 JAX 保证,但是由于 omnistaging 之前的“跟踪时常量折叠”行为,某些有副作用的函数仍然可以正常运行。Omnistaging 会捕获更多此类错误。
解决方案#
解决方案是识别依赖副作用的 JAX 转换函数,并重写它们,使其不产生副作用。
基于 XLA 优化的微小数值差异#
由于使用全暂存 (omnistaging) 时,更多的计算被暂存到 XLA,而不是在跟踪时执行,这可能会导致浮点运算的重新排序。因此,我们观察到数值行为发生了变化,导致当启用全暂存时,具有过紧容差的测试会失败。
依赖于已更改的 JAX 内部 API#
全暂存涉及对 JAX 核心代码的一些重大修订,包括删除或更改内部函数。任何依赖于此类 JAX 内部 API 的代码在启用全暂存时都可能中断,要么出现构建错误 (来自 pytype),要么出现运行时错误。
触发 XLA 编译时错误#
由于全暂存涉及将更多代码暂存到 XLA,我们发现它会在某些后端触发预先存在的 XLA 编译时错误。处理这些问题的最佳方法是报告它们,以便我们可以与 XLA 团队合作进行修复。