Omnistaging#
mattjj@ 2020 年 9 月 25 日
这更像是一份升级指南,而不是设计文档。
目录#
简而言之#
发生了什么?#
JAX 的一个名为“omnistaging”的追踪基础设施的更改( jax-ml/jax#3370 )已在 jax==0.2.0 中启用。此更改提高了内存性能、跟踪执行时间,并简化了 JAX 内部结构,但可能会导致一些现有代码出错。出错通常是由于代码错误导致的,因此从长远来看,最好修复这些错误,但也可以禁用 omnistaging 作为临时解决方法。我们很乐意帮助您解决问题!
我如何知道 omnistaging 是否破坏了我的代码?#
判断 omnistaging 是否是罪魁祸首的最简单方法是禁用 omnistaging 并查看问题是否消失。请参阅下面的 启用 omnistaging 后可能出现什么问题? 部分。
我现在如何禁用 omnistaging?#
注意:这适用于 JAX 版本 0.2.0 到 0.2.11;在 JAX 版本 0.2.12 及更高版本中无法禁用 omnistaging。
暂时可以通过以下方式禁用 omnistaging:
将 shell 环境变量
JAX_OMNISTAGING设置为假值;如果您的代码使用 absl 解析标志,请将布尔标志
jax_omnistaging设置为假值;在主文件的顶部附近使用此语句
jax.config.disable_omnistaging()
如何修复 omnistaging 暴露的 bug?#
迄今为止,omnistaging 最常见的问题是使用 jax.numpy 计算形状值或其他跟踪时常量。请参阅下面的代码块,了解快速示例,并获取其他问题的详细信息,请参阅 启用 omnistaging 后可能出现什么问题? 部分。
不要这样做
@jit
def f(x):
input_size = jnp.prod(x.shape)
if input_size > 100:
...
这样做
import numpy as np
@jit
def f(x):
input_size = np.prod(x.shape)
if input_size > 100:
...
与其将 jax.numpy 视为 numpy 的即插即用替代品,不如现在更好地将其视为仅在您希望在加速器(如 GPU)上执行计算时使用 jax.numpy 操作。
什么是“omnistaging”,它有什么用?#
Omnistaging 是 JAX 核心升级的名称,旨在将更多计算从逐个 op 的 Python 转移到 XLA,并避免 jit、pmap 和控制流原语中的任何“跟踪时常量折叠”。因此,omnistaging 提高了 JAX 的内存性能(有时是巨大的),方法是减少跟踪过程中的碎片,并为 XLA 生成更少的、大的编译时常量。它还可以通过消除跟踪时的逐个 op 执行来提高跟踪性能。此外,omnistaging 简化了 JAX 核心内部结构,修复了许多待处理的 bug,并为重要的未来功能奠定了基础。
“omnistaging”这个名字的意思是尽可能地进行分阶段处理。
玩具示例#
像 jit 和 pmap 这样的 JAX 变换会将计算分阶段到 XLA。也就是说,我们将它们应用于包含多个基本操作的函数,以便不是一个接一个地从 Python 执行,而是将所有操作都作为一次端到端优化的 XLA 计算的一部分。
但具体哪些操作会被分阶段处理?直到 omnistaging,JAX 仅根据数据依赖性分阶段处理计算。下面是一个示例函数,以及它在 omnistaging 更改之前分阶段处理的 XLA HLO 程序。
from jax import jit
import jax.numpy as jnp
@jit
def f(x):
y = jnp.add(1, 1)
return x * y
f(3)
ENTRY jit_f.6 {
constant.2 = pred[] constant(false)
parameter.1 = s32[] parameter(0)
constant.3 = s32[] constant(2)
multiply.4 = s32[] multiply(parameter.1, constant.3)
ROOT tuple.5 = (s32[]) tuple(multiply.4)
}
请注意,add 操作未被分阶段处理。相反,我们只看到一个乘法。
以下是 omnistaging 更改之后此函数生成的 HLO。
ENTRY jit_f.8 {
constant.2 = pred[] constant(false)
parameter.1 = s32[] parameter(0)
constant.3 = s32[] constant(1)
constant.4 = s32[] constant(1)
add.5 = s32[] add(constant.3, constant.4)
multiply.6 = s32[] multiply(parameter.1, add.5)
ROOT tuple.7 = (s32[]) tuple(multiply.6)
}
稍微不那么玩具的例子#
这是一个在实际中可能出现的、稍微不那么玩具的例子,当我们想要创建布尔掩码时。
import jax.numpy as jnp
from jax import lax
@jit
def select_tril(x):
mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
return lax.select(mask, x, jnp.zeros_like(x)) # lax.select is like jnp.where
x = np.arange(12).reshape((3, 4))
select_tril(x)
omnistaging之前
ENTRY jit_select_tril.8 {
constant.3 = pred[] constant(false)
constant.1 = pred[3,4]{1,0} constant({...})
parameter.2 = s32[3,4]{1,0} parameter(0)
constant.4 = s32[] constant(0)
broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}
执行了 select 操作,但构造常量 mask 的操作没有。与其分阶段处理,不如在 Python 跟踪时逐个 op 执行构造 mask 的操作,XLA 只看到一个表示 mask 值的编译时常量 constant.1。这很不理想,因为如果我们分阶段处理了构造 mask 的操作,XLA 就可以将它们融合到 select 中,完全避免物化结果。因此,我们最终会用一个潜在的大常量浪费内存,浪费时间分派多个未融合的逐个 op XLA 计算,甚至可能导致内存碎片。
(与构造 jnp.zeros_like(x) 的零数组对应的 broadcast 操作被分阶段处理,因为 JAX 对非常简单的表达式是惰性的( jax-ml/jax#1668 )。在 omnistaging 之后,我们可以删除那个惰性子语言并简化 JAX 内部结构。
构造 mask 没有被分阶段处理的原因是,在 omnistaging 之前,jit 是基于数据依赖性操作的。也就是说,jit 只分阶段处理函数中与参数有数据依赖性的操作。控制流原语和 pmap 的行为类似。在 select_tril 的情况下,构造常量 mask 的操作与参数 x 没有数据依赖性,因此它们没有被分阶段处理;只有 lax.select 调用有数据依赖性。
通过 omnistaging,在 jit 转换的函数的动态上下文中,所有 jax.numpy 调用都会被分阶段处理到 XLA。也就是说,在 omnistaging 之后,XLA 为 select_tril 看到的计算是:
ENTRY jit_select_tril.16 {
constant.4 = pred[] constant(false)
iota.1 = s32[3]{0} iota(), iota_dimension=0
broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
reshape.7 = s32[3]{0} reshape(broadcast.5)
broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
iota.2 = s32[4]{0} iota(), iota_dimension=0
broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
reshape.9 = s32[4]{0} reshape(broadcast.6)
broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
parameter.3 = s32[3,4]{1,0} parameter(0)
constant.12 = s32[] constant(0)
broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}
启用 omnistaging 后可能出现什么问题?#
由于在 jit 或 pmap 的动态上下文中,所有 jax.numpy 操作都从 Python 分阶段处理到 XLA,因此一些以前可以正常工作的代码可能会开始出现明显的错误。如下所述,这些行为在 omnistaging 之前已经是错误的,但 omnistaging 使它们成为硬错误。
使用 jax.numpy 进行形状计算#
示例#
from jax import jit
import jax.numpy as jnp
@jit
def ex1(x):
size = jnp.prod(jnp.array(x.shape))
return x.reshape((size,))
ex1(jnp.ones((3, 4)))
错误消息#
[... full traceback ...]
File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The error arose in jax.numpy.reshape.
While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:
operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
from line ex1.py:6 (ex1)
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.net.cn/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
解释#
使用 omnistaging 时,我们不能像上面使用 jnp.prod 那样在 JIT 函数的动态上下文中为形状计算使用 jax.numpy,因为那些操作将被分阶段处理为在执行时计算的值,但我们需要它们是编译时(因此是跟踪时)常量。
在 omnistaging 之前,这段代码不会报错,但这是一个常见的性能 bug:jnp.prod 计算将在跟踪时在设备上执行,这意味着额外的编译、传输、同步、分配,以及潜在的内存碎片。
解决方案#
解决方案很简单,就是使用原始的 numpy 来执行此类形状计算。我们不仅避免了错误,而且将计算保留在主机上(并且开销更低)。
此问题在代码中非常普遍,我们试图使错误消息特别好。除了显示抽象跟踪器值导致问题的堆栈跟踪(在 omni.py:10 处的 jnp.reshape 行)之外,我们还通过指向导致该值成为跟踪器的上游原始操作(来自 jnp.prod 的 reduce_prod,在 omni.py:9 处)以及该跟踪器所属的 jit 装饰函数(在 omni.py:6 处的 ex1)来解释为什么该值会成为跟踪器。
副作用#
示例#
from jax import jit
from jax import random
key = random.PRNGKey(0)
def init():
global key
key, subkey = random.split(key)
return random.normal(subkey, ())
print(init()) # -1.2515389
print(init()) # -0.58665067
init = jit(init)
print(init()) # 0.48648298
print(init()) # 0.48648298 !!
最后一个调用有重复的随机性但没有硬错误,因为我们没有重新执行 Python。但是如果我们查看 key,当 omnistaging 开启时,我们会看到一个逃逸的跟踪器。
print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
在 omnistaging 之前,random.split 调用不会被分阶段处理,因此我们不会得到逃逸的跟踪器。代码仍然是错误的,因为 JIT 转换的函数不会重现原始函数的语义(由于重复使用相同的 PRNG 密钥),最终是由于副作用。
开启 omnistaging 后,如果我们再次触碰 key,我们会得到一个逃逸的跟踪器错误。
random.normal(key, ())
错误消息#
[... full stack trace …]
File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).
解释#
我们发现的第二大类 omnistaging 问题与副作用代码有关。这些代码通过转换有副作用的函数来破坏 JAX 的保证,但由于预 omnistaging 的“跟踪时常量折叠”行为,一些有副作用的函数仍然可以正常工作。Omnistaging 捕获了更多此类错误。
解决方案#
解决方案是识别依赖于副作用的 JAX 转换函数,并重写它们以消除副作用。
基于 XLA 优化的微小数值差异#
由于 omnistaging 将更多计算分阶段到 XLA,而不是在跟踪时执行一些计算,这可能会导致浮点运算重新排序。因此,我们已经看到数值行为发生变化,导致在启用 omnistaging 时,具有过于严格容差的测试失败。
依赖于已更改的 JAX 内部 API#
Omnistaging 涉及 JAX 核心代码的一些重大修订,包括删除或更改内部函数。任何依赖于此类内部 JAX API 的代码在启用 omnistaging 时都可能中断,出现构建错误(来自 pytype)或运行时错误。
触发 XLA 编译时 bug#
由于 omnistaging 涉及将更多代码分阶段到 XLA,我们已经看到它在某些后端上触发了预先存在的 XLA 编译时 bug。处理这些问题的最佳方法是报告它们,以便我们与 XLA 团队合作进行修复。