针对可 JAX 转换的函数的自定义 JVP/VJP 规则#
这是一份设计文档,解释了 jax.custom_jvp
和 jax.custom_vjp
的设计和实现背后的一些思考。有关面向用户的文档,请参阅教程笔记本。
在 JAX 中,有两种定义微分规则的方法
使用
jax.custom_jvp
和jax.custom_vjp
为已经可 JAX 转换的 Python 函数定义自定义微分规则;以及定义新的
core.Primitive
实例及其所有转换规则,例如调用来自其他系统的函数,如求解器、模拟器或通用数值计算系统。
本文档仅关于 #1。
目录#
目标#
我们希望用户自定义其代码的前向和/或反向模式微分行为。此自定义
在工作方式以及如何与其他 JAX 转换组合方面,应具有清晰且一致的语义;并且
应该灵活地支持诸如 Autograd 和 PyTorch 中的用例和工作流程,包括涉及 Python 控制流的微分以及 NaN 调试的工作流程。
作为 JAX 开发人员,我们希望编写库函数,例如 logit
和 expit
,这些函数是根据其他原语定义的,但出于微分的目的,它们具有类似原语的行为,因为我们希望为它们定义自定义微分规则,这些规则可能在数值上更稳定或性能更高。特别是,我们不想为像 logit
和 expit
这样的函数指定 vmap
或 jit
规则。
作为一个延伸目标,我们希望使 JAX 成为高级用户添加诸如 fixed_point
、odeint
等高阶函数的自定义微分规则的理想环境;此设计文档不会解决该问题,但我们希望确信我们不会排除该问题的良好解决方案。
也就是说,我们的主要目标是
次要目标是 3. 清理和简化用户体验(符号零、kwargs 等)4.朝着用户可以轻松添加 fixed_point
、odeint
、root
等的世界迈进。
总的来说,我们希望关闭 #116、#1097、#1249、#1275、#1366、#1723、#1670、#1875、#1938,并替换 custom_transforms 机制(来自 #636、#818 等)。
非目标#
以下是我们不打算实现的目标
custom_transforms
机制旨在提供一种转换通用的机制来定制行为,原则上(尽管实际上从未真正使用过)允许用户自定义任何转换的规则,同时以某种方式继承其他转换的“透明”行为。我们反而仅打算解决微分(JVP 和 VJP,分别)的自定义问题。微分是唯一实际请求的情况,通过专门针对微分,我们可以降低复杂性并提高灵活性。要控制所有规则,只需编写一个原语即可。我们不会优先考虑数学上的美感,而不是用户端的灵活性和清晰度,以及实现端的简单性。特别是,虽然自定义 VJP 签名
a -> (b, CT b --o CT a)
在数学上令人满意,但如果由于返回类型中的闭包而难以在 Python 机制中实现,那么我们可以更明确地处理残差。序列化支持,即分段输出的序列化程序表示可以加载并进一步进行 JAX 转换,而不仅仅是评估,目前不在这些自定义 JVP/VJP 转换规则的范围内。序列化不仅对希望保存其计算的某些表示形式(并在加载后对其进行转换)的研究人员有用,而且对于未来的考虑因素(例如,在 Python 外部实现 jaxpr 转换或将 jaxpr 用作 MLIR 方言)也可能有用。通过将此定义为本设计的非目标,我们对可以在何处存储 Python 可调用对象施加的约束更少。
主要问题描述#
vmap 移除自定义 jvp 的语义问题#
vmap 移除自定义 jvp 的语义问题是 vmap 不能与具有 custom_transforms
规则的函数的微分正确组合
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
return 2. * x
# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
return f(x), lambda g: 3. * x # 3 instead of 2
jax.defvjp_all(f, f_vjp)
grad(f)(1.) # 3.
vmap(grad(f))(np.ones(4)) # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4)) # [2., 2., 2., 2.]
最后一个 grad-of-vmap 行有一个意想不到的结果!通常,应用 vmap
,或者实际上是任何非微分转换,都会产生删除自定义微分规则的效果。(当定义自定义 VJP 规则时,应用 jvp
会导致失败。)
存在此问题的原因是转换就像重写,而 vmap
转换实际上会重写该函数,使其不再调用新引入的具有自定义规则的原语(因此 grad
随后不会生成自定义规则的结果)。更详细地说,custom_transforms
机制的设置使得评估 f(x)
应用函数
{ lambda ; ; a.
let b = f_primitive a
in [b] }
其中 f_primitive
是一个新的原语(为每个 custom_transforms
函数引入,实际上是为每次调用该函数引入的),自定义 VJP 规则与其关联。当我们评估 grad(f)(x)
时,微分机制会遇到 f_primitive
并使用自定义规则对其进行处理。
但是,由于 f_primitive
对 vmap
是透明的,即 vmap
通过内联方式对 f_primitive
的定义进行操作,因此函数 vmap(f)
实际上是
{ lambda ; ; a.
let b = mul 2. a
in [b] }
换句话说,vmap
会根据其底层原语及其转换规则重写该函数,从而完全删除 f_primitive
。
更一般地说,由于 vmap(f)
的语义是根据对 f 的调用定义的,因此删除自定义导数规则在语义上是不一致的。也就是说,由于我们定义了
vmap(f)(xs) == np.stack([f(x) for x in xs])
我们必须有
jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))
但是,当 f
定义了自定义导数规则时,不会观察到此属性,因为自定义导数规则在右侧版本中使用,而不在左侧版本中使用。
此问题并非特定于 vmap
;它适用于所有转换,对于这些转换,转换函数 f
的语义是根据对函数 f
的调用定义的,而不是将其重写为另一个函数。mask
转换也属于此类。微分转换和假设的“所有一元函数都变为余弦”的转换不属于此类。
(额外的自定义规则(例如自定义 vmap
规则)之间的交互可能会变得更加复杂,这表明 custom_transforms
的问题框架过于宽泛。)
Python 的灵活性问题#
在 JAX 中,与 Autograd 和 PyTorch 类似,但与 TF1 不同的是,Python 函数的微分是在函数执行和追踪时进行的。这种行为让用户感到高兴的原因有几个。
首先也是最重要的是,它支持基于 pdb 的工作流程,例如用于检查数值或捕获 NaN。也就是说,用户可以使用标准的 Python 调试器和其他 Python 原生工具来调试他们的代码,甚至能够检查运行时值以了解示例中的数值行为并捕获像 NaN 这样的根本运行时错误。事实上,就在处理与此设计对应的 PR 时,特别是在 odeint
原语上,我多次使用运行时值检查来调试问题,这让我更加确信这是 Python 中的关键用户工作流程。一个特别方便的技巧,我在 JAX 和 Autograd 中都多次使用过,就是在自定义 VJP 规则中插入一个调试器断点,以便在反向传播中的特定点进入调试器。
其次,它允许对 Python 原生控制流进行微分。我们不确定在最终的软件制品中这种情况在实践中使用的频率如何,但是当用户第一次尝试使用 JAX 或 Autograd 时,他们通常会对这种自由感到印象深刻。这就是为什么我们将其放在 JAX 和 Autograd 的 README、幻灯片和演示的顶部的原因。放弃这种能力将是从 Autograd 的倒退。我们希望 JAX 拥有最好的自动微分。
然而,custom_transforms
机制不提供这种 Python 支持的灵活性。也就是说,因为它是在用户函数和自定义微分规则的 Python 代码中,根据预先生成的 jaxpr 实现的,所以像这样的代码会导致抽象值追踪错误
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
if x > 0:
return x
else:
return 0.
def f_vjp(x):
return ...
jax.defvjp_all(f, f_vjp)
grad(f)(1.) # Error!
解决方案思路#
主要思路是 dougalm@ 已经使用 core.call
解决了这些问题。也就是说,我们可以将为用户函数指定自定义 JVP 规则的任务,定义为新的 Python 级调用原语(不添加到 jaxpr 语言;见下文)。这个新的调用原语与 core.call
一样,有一个与之关联的用户 Python 函数,但另外还有一个表示 JVP 规则的第二个 Python 可调用对象。让我们将这个新的调用原语称为 custom_jvp_call
。
像 vmap
这样的转换与 custom_jvp_call
的交互方式与 core.call
相同:它们有效地直接通过它,并应用于底层 Python 可调用对象。为了方便起见,使用原语的柯里化版本进行示意性书写,类似于 vmap
如何通过应用于要调用的函数来与 core.call
交互
vmap(call(f)) == call(vmap(f))
对于新的原语 custom_jvp_call
,我们只需将 vmap
应用于它所包含的两个函数
vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))
此行为意味着我们已解决了vmap 删除自定义 jvp 语义的问题。
jvp
转换的交互方式正如人们所期望的那样:它只是调用 f_jvp
,
jvp(call(f)) == call(jvp(f))
jvp(custom_jvp_call(f, f_jvp)) == f_jvp
因为 custom_jvp_call
的行为类似于 core.call
(而不是像 xla.xla_call
),因为它不会提高其输入的抽象级别(因为它没有延迟任何内容或将任何内容分阶段输出),这意味着我们已经解决了Python 灵活性问题:用户 Python 函数没有约束(除了 jvp
或 vjp
所需的通常的函数式编程约束)。
评估和编译呢?这是“退出” JAX 系统的两种方式,因为在这些步骤之后不能应用其他转换。因此,它们的规则很简单
eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))
eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))
换句话说,如果 JVP 规则尚未将 custom_jvp_call(f, f_jvp)
重写为 f_jvp
,那么当我们使用 eval
进行评估或使用 jit
分阶段输出到 XLA 时,永远不会应用微分,因此我们只是忽略 f_jvp
并像 core.call
一样运行。但是,由于接下来讨论的难题,custom_jvp_call
的部分评估规则必须稍微复杂一些,因为部分评估不仅仅用于使用 jit
分阶段输出到 XLA。
剩下的唯一难题与 “initial-style” jaxpr 形成原语(如 lax.scan
)及其转换规则有关。这些代表了一种不同类型的“分阶段输出到 jaxpr” 的方式,与编译不同,因为我们可以对分阶段输出的 jaxpr 执行其他转换。也就是说,当 lax.scan
形成 jaxpr 时,它不会退出转换系统,因为当我们对 lax.scan
应用 jvp 或 vmap 时,我们需要将其应用于 jaxpr 表示的函数。
陈述这个难题的另一种方式是,像 lax.scan
这样的 initial-style 原语依赖于在 jaxpr 和 Python 可调用对象之间来回传递的能力,同时保留语义。这也必须意味着保留自定义微分规则的语义。
解决方案是使用一些动态作用域:当我们将 initial-style 原语(如 lax_control_flow.py 中的原语)分阶段输出到 jaxpr 时,我们在全局跟踪状态上设置一个位。当该位被设置时,我们不使用 final-style 的 custom_jvp_call
原语,而是使用 initial-style 的 custom_jvp_call_jaxpr
原语,并且预先将函数 f
和 f_jvp
追踪到 jaxpr 中,以简化 initial-style 处理。custom_jvp_call_jaxpr
原语与其他方面与 final-style 版本类似。
(脚注:虽然在道德上我们在绑定 custom_jvp_call_jaxpr
之前为 f
和 f_jvp
都形成了 jaxpr,但是我们需要延迟 f_jvp
的 jaxpr 的形成,因为它可能会调用自定义 JVP 函数,因此急切的处理会导致无限递归。我们在一个 thunk 中延迟了该 jaxpr 的形成。)
如果我们放弃了 Python 的灵活性问题,我们可以只使用 custom_jvp_call_jaxpr
而不使用单独的 Python 级原语 custom_jvp_call
。
API#
函数 a -> b
的自定义 JVP 使用 (a, Ta) -> (b, T b)
函数指定
# f :: a -> b
@jax.custom_jvp
def f(x):
return np.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), np.cos(x) * t
f.defjvp(f_jvp)
(有趣的自动微分旁注:要使规则应用于高阶微分,必须在 f_jvp
的主体中调用 f
;这会排除 f
的内部和切线计算之间的一些工作共享。)
函数 a -> b
的自定义 VJP 使用 a -> (b, c)
前向传递函数与 (c, CT b) -> CT
后向传递函数配对指定
# f :: a -> b
@jax.custom_vjp
def f(x):
return np.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
return f(x), np.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
return (cos_x * g,)
f.defvjp(f_fwd, f_bwd)
签名 a -> (b, CT b --o CT a)
在美学上更令人愉悦,但支持它会使实现更加复杂,并可能需要损害可表达性的需求。基本原因是 Python 可调用对象是不透明的(除非我们急切地将它们追踪到 jaxpr,这会带来表达性约束),在这种情况下,我们可能会返回一个在闭包内包含 vmap
追踪器的可调用对象,我们需要在前向传递期间了解这些追踪器。
我们可以添加方便的包装器,例如,一次为一个参数定义 JVP 规则(就像我们在内部为原语所做的那样)。但是因为这个提案已经足够复杂,所以我决定反对方便层;现在让我们保持最小化。
API 还有一些其他的花哨功能
输入和输出类型
a
、b
和c
可以是 jaxtypes 的任意 pytrees。当可以使用
inspect
模块将其解析为位置时,支持按名称(关键字参数)传递参数。这是对 Python 3 改进的以编程方式检查参数签名的能力的一个小实验。我相信它是合理的,但并不完整,这是一个很好的起点。(另请参阅 #2069。)可以使用
nondiff_argnums
将参数标记为不可微分,与jit
的static_argnums
一样,这些参数不必是 JAX 类型。我们需要为这些参数如何传递给规则设定一个约定。对于具有类型签名(d, a) -> b
的原始函数,其中d
表示不可微分的类型,JVP 规则的签名是(a, T a, d) -> T b
,而 VJP 规则的反向分量签名是(d, c, CT b) -> CT a
。也就是说,对于自定义 JVP 规则,不可微分的参数在primals
和tangents
之后按顺序传递,而在自定义 VJP 规则的反向函数中,则在残差之前按顺序传递。
实现说明#
更新
jax.experimental.odeint
由于
odeint
是一个相当复杂的自定义 VJP 规则用户,除了更新它以使其正常工作之外,我还想修改它,使其成为新自定义 VJP API 的规范用户,以此来测试该 API 是否良好。在此过程中,我对
odeint
的实现进行了其他改进删除 raveling/unraveling 样板代码
利用
lax.scan
删除索引更新逻辑在简单的摆锤基准测试中速度提高了 20+%
为自定义导数调用原语
custom_jvp_call
和custom_vjp_call
的每个变换添加了自定义绑定方法。它类似于core.call_bind
,但我们不处理环境跟踪:那些只是错误。添加了
custom_lin
原语,当使用自定义 VJP 规则时,它会被分阶段输出到线性 jaxpr 以进行转置。由于我们的反向模式自动微分被分解为线性化、部分求值和转置,我们的自定义 VJP 规则被分为两个单独的步骤进行处理:一个在线性化期间,另一个在转置期间。
线性化步骤,即
custom_vjp_call
的 JVP 规则,将custom_lin
应用于切线值;custom_lin
带有用户自定义的后向传递函数,并且作为原语,它只有转置规则。此机制在 #636 中有更详细的描述。
为了防止