JAX 可转换函数自定义 JVP/VJP 规则#

这是一份设计文档,解释了 jax.custom_jvpjax.custom_vjp 设计和实现背后的一些思考。面向用户的文档,请参阅 教程 notebook

在 JAX 中定义微分规则有两种方式:

  1. 使用 jax.custom_jvpjax.custom_vjp 为已是 JAX 可转换的 Python 函数定义自定义微分规则;并且

  2. 定义新的 core.Primitive 实例及其所有变换规则,例如调用其他系统(如求解器、模拟器或通用数值计算系统)的函数。

本文档仅涉及 #1。

目录#

目标#

我们希望 **用户** 能够自定义其代码的正向和/或反向模式微分行为。这种自定义

  1. 在工作方式和与其他 JAX 变换的组合方式上应具有*清晰一致的语义*;并且

  2. 应*灵活*支持 AutogradPyTorch 中的用例和工作流程,包括涉及 Python 控制流微分和 NaN 调试的工作流程。

作为 **JAX 开发者**,我们希望编写一些库函数,例如 logitexpit,这些函数基于其他原始操作定义,但在微分目的上具有原始操作般的行为,即我们希望为它们定义自定义微分规则,这些规则可能在数值上更稳定或性能更好。特别是,我们不希望为 logitexpit 等函数指定 vmapjit 规则。

作为一项扩展目标,我们希望使 JAX 成为一个强大的环境,供高级用户为高阶函数(如 fixed_pointodeint 等)添加自定义微分规则;本设计文档不会解决该问题,但我们希望确保我们不会排除解决该问题的良好方案。

也就是说,我们的主要目标是

  1. 解决 vmap 移除自定义 jvp 语义问题(#1249),以及

  2. 允许自定义 VJP 中的 Python,例如用于调试 NaN(#1275)。

次要目标是 3. 优化和简化用户体验(符号零、关键字参数等)4. 向用户可以轻松添加 fixed_pointodeintroot 等功能的时代迈进。

总的来说,我们希望解决 #116#1097#1249#1275#1366#1723#1670#1875#1938,并替换 custom_transforms 机制(来自 #636#818 等)。

非目标#

以下是我们**不**打算实现的目标

  1. custom_transforms 机制旨在提供一个通用的变换机制来定制行为,原则上(尽管实际上很少使用)允许用户为任何变换定制规则,同时某种程度上继承其他变换的“透明”行为。**相反,我们只解决微分(JVP 和 VJP,分开)的自定义问题。** 微分是唯一实际被请求的情况,通过专门化微分,我们可以降低复杂性并提高灵活性。要控制所有规则,只需编写一个原始操作。

  2. **我们不会优先考虑数学美学**,而会优先考虑用户端的灵活性和清晰度,以及实现端的简洁性。特别是,虽然自定义 VJP 签名 a -> (b, CT b --o CT a) 在数学上是令人愉悦的,但如果因为它在返回类型中的闭包而在 Python 机制中难以实现,我们乐于采用一种更明确地处理残差的方法。

  3. **序列化支持**,即能够加载并进一步 JAX 变换(而不是仅评估)的暂存的序列化程序表示,目前不属于这些自定义 JVP/VJP 变换规则的范围。序列化不仅对想要保存其计算表示(并在加载后进行变换)的研究人员有用,而且对未来考虑(如在 Python 外部实现 jaxpr 变换,或将 jaxprs 作为 MLIR 方言)也有用。将此作为本次设计的非目标,我们可以对将 Python 可调用对象存放在何处施加更少的限制。

主要问题描述#

vmap 移除自定义 jvp 语义问题#

vmap 移除自定义 jvp 语义问题是:vmap 与具有 custom_transforms 规则的函数微分无法正确组合

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

最后一行 grad-of-vmap 产生了意外结果!一般来说,应用 vmap,或者实际上任何非微分变换,都会导致自定义微分规则被移除。(应用 jvp 会在定义了自定义 VJP 规则时失败。)

问题之所以存在,是因为变换就像重写,而 vmap 变换有效地重写了函数,使其不再调用新引入的具有自定义规则的原始操作(因此 grad 就不会产生自定义规则的结果)。更详细地说,custom_transforms 机制的设置使得评估 f(x) 时应用了该函数

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

其中 f_primitive 是一个为每个 custom_transforms 函数(实际上是每次函数调用)引入的新原始操作,自定义 VJP 规则与其关联。当评估 grad(f)(x) 时,微分机制遇到 f_primitive 并使用自定义规则进行处理。

然而,因为 f_primitivevmap 是*透明*的,即 vmap 操作(通过内联)f_primitive 的定义,所以函数 vmap(f) 实际上是

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

换句话说,vmap 会根据其底层原始操作及其变换规则重写函数,从而完全移除 f_primitive

更普遍地说,**因为 vmap(f) 的语义是基于对 f 的调用定义的,移除自定义导数规则在语义上是不一致的**。也就是说,由于我们定义了

vmap(f)(xs) == np.stack([f(x) for x in xs])

我们必须有

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))

然而,当 f 定义了自定义导数规则时,这一属性并不成立,因为右侧版本使用了自定义导数规则,而左侧版本没有。

这个问题并非仅限于 vmap;它适用于所有变换,这些变换的语义是将函数 f 变换为基于函数 f 的调用,而不是重写为另一个函数。mask 变换也属于此类。微分变换和假设的“所有一元函数变成余弦”变换不属于此类。

(附加自定义规则(如自定义 vmap 规则)的交互可能会更加复杂,这表明 custom_transforms 的问题框架过于宽泛。)

Python 灵活性问题#

在 JAX 中,与 AutogradPyTorch 一样(但与 TF1 不同),Python 函数的微分是在函数执行和跟踪期间进行的。这种行为让用户感到高兴,原因如下。

**首先,也是最重要的,它支持基于 pdb 的工作流程,例如用于检查数值或捕获 NaN。** 也就是说,用户可以使用标准的 Python 调试器和其他 Python 原生工具来调试代码,甚至可以检查运行时值以了解示例上的数值行为,并捕获本质上是运行时错误(如 NaN)。事实上,在编写本设计对应的 PR 时,特别是在 odeint 原始操作时,我曾多次使用运行时值检查来调试问题,这增加了我对这是 Python 中关键用户工作流程的信心。一个特别方便的技巧是,我在 JAX 和 Autograd 中都多次使用过,能够在自定义 VJP 规则中插入调试器断点,以便在反向传播的特定点进入调试器。

**其次,它允许微分 Python 原生控制流。** 我们不确定在实际的最终软件制品中这种情况使用频率如何,但当用户初次接触 JAX 或 Autograd 时,他们常常对其自由度感到印象深刻。我们之所以在 JAX 和 Autograd 的 README、幻灯片和演示文稿的开头就包含这一点,是有原因的。放弃这种能力将是 Autograd 的倒退。我们希望 JAX 拥有最好的自动微分。

然而,custom_transforms 机制并未提供这种 Python 支持的灵活性。也就是说,由于它是通过为用户函数和自定义微分规则的 Python 代码进行预先的 jaxpr 形成来实现的,因此像这样的代码会导致抽象值跟踪错误

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!

解决方案构想#

主要思想是 **dougalm@ 已经使用 core.call 解决了这些问题**。也就是说,我们可以将为用户函数指定自定义 JVP 规则的任务构架为一个新的 Python 级别调用原始操作(不添加到 jaxpr 语言;见下文)。这个新的调用原始操作有一个关联的用户 Python 函数,就像 core.call 一样,但此外还有一个代表 JVP 规则的第二个 Python 可调用对象。我们将其称为 custom_jvp_call

vmap 这样的变换与 custom_jvp_call 的交互方式与 core.call 类似:它们会直接通过它,并应用于底层的 Python 可调用对象。示意性地,为了方便起见,我们以带柯里化版本的原始操作来写,这与 vmap 如何通过应用于要调用的函数来与 core.call 交互类似

vmap(call(f)) == call(vmap(f))

对于新的原始操作 custom_jvp_call,我们只需将 vmap 应用于它包含的两个函数

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

这种行为意味着我们已经解决了vmap 移除自定义 jvp 语义问题

jvp 变换的交互方式正如人们所期望的那样:它只是调用 f_jvp

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

因为 custom_jvp_call 的作用类似于 core.call(而不是 xla.xla_call),因为它不提高其输入的抽象级别(因为它没有延迟任何内容或暂存任何内容),这意味着我们已经解决了Python 灵活性问题:对用户 Python 函数没有限制(除了 jvpvjp 通常所需的函数式编程约束)。

那么评估和编译呢?这是“退出”JAX 系统的两种方式,意味着在此之后不能应用任何其他变换。因此,它们的规则很简单

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

换句话说,如果 JVP 规则尚未将 custom_jvp_call(f, f_jvp) 重写为 f_jvp,那么当通过 eval 进行评估或通过 jit 暂存到 XLA 时,微分永远不会被应用,因此我们只会像 core.call 一样忽略 f_jvp。然而,由于接下来讨论的细微差别,custom_jvp_call 的部分求值规则必须更复杂一些,因为部分求值不仅仅用于通过 jit 暂存到 XLA。

唯一剩下的细微差别与“初始风格”的 jaxpr 形成原始操作(如 lax.scan)及其变换规则有关。这些代表了与编译不同的“暂存到 jaxpr”的类型,因为我们可以在暂存的 jaxpr 上执行额外的变换。也就是说,当 lax.scan 形成 jaxpr 时,它并不会退出变换系统,因为当我们对 lax.scan 应用 jvp 或 vmap 时,我们需要将其应用于 jaxpr 所代表的函数。

另一种陈述这种细微差别的方式是,像 lax.scan 这样的初始风格原始操作依赖于能够将 jaxpr 往返并返回 Python 可调用对象,同时保持语义。这意味着也必须保持自定义导数规则的语义。

解决方案是使用一些动态作用域:当我们为初始风格的原始操作(如 lax_control_flow.py 中的原始操作)暂存到 jaxpr 时,我们会设置全局跟踪状态中的一个标志。当设置了该标志时,我们使用初始风格的 custom_jvp_call_jaxpr 原始操作,而不是最终风格的 custom_jvp_call 原始操作,并预先跟踪函数 ff_jvp 的 jaxpr,以便更容易进行初始风格处理。custom_jvp_call_jaxpr 原始操作在其他方面与最终风格的版本相似。

(脚注:虽然在道德上我们在绑定 custom_jvp_call_jaxpr 之前为 ff_jvp 都形成了 jaxpr,但我们需要延迟 f_jvp 的 jaxpr 形成,因为它可能会调用自定义 JVP 函数,因此立即处理会导致无限递归。我们在 thunk 中延迟了该 jaxpr 的形成。)

如果我们放弃Python 灵活性问题,那么我们只需拥有 custom_jvp_call_jaxpr 而不需要独立的 Python 级别原始操作 custom_jvp_call

API#

一个 a -> b 函数的自定义 JVP 使用一个 (a, Ta) -> (b, T b) 函数来指定

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(有趣的自动微分提示:为了使规则适用于高阶微分,必须在 f_jvp 的主体中调用 f;这会排除 f 内部与切线计算之间某些类型的工作共享。)

一个 a -> b 函数的自定义 VJP 使用一个 a -> (b, c) 前向传播函数,与一个 (c, CT b) -> CT a 后向传播函数配对

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

签名 a -> (b, CT b --o CT a) 在数学上更具吸引力,但支持它会使实现更复杂,并且可能需要妥协可表达性要求。基本原因是 Python 可调用对象是不透明的(除非我们立即将其跟踪到 jaxpr,这会带来可表达性约束),在这种情况下,我们可能会返回一个在其闭包中带有 vmap 跟踪器的可调用对象,而我们在前向传播期间需要了解它。

我们可以添加便利包装器,例如一次定义单个参数的 JVP 规则(就像我们对原始操作内部进行的那样)。但是,由于此提案本身已经足够复杂,我决定不添加便利层;暂时保持最小化。

API 还有一些其他附加功能

  • 输入和输出类型 abc 可以是任意的 jaxtypes 复合结构(pytree)。

  • 当可以通过 inspect 模块解析为位置时,支持按名称(关键字参数)传递参数。这是对 Python 3 改进的程序化检查参数签名能力进行的一次实验。我认为这是可靠但不完整的,这已经是一个不错的结果。(另请参见 #2069。)

  • 可以使用 nondiff_argnums 将参数标记为不可微分,并且与 jitstatic_argnums 一样,这些参数不必是 JAX 类型。我们需要为这些参数如何传递给规则设定约定。对于类型签名为 (d, a) -> b 的原始函数,其中 d 代表不可微分类型,JVP 规则的签名是 (a, T a, d) -> T b,VJP 规则的反向组件签名是 (d, c, CT b) -> CT a。也就是说,不可微分参数按顺序在自定义 JVP 规则的 primalstangents 之后传递,并在自定义 VJP 规则的反向函数中在残差之前传递。

实现说明#

  • 更新了 jax.experimental.odeint

    • 由于 odeint 是一个复杂的自定义 VJP 规则用户,除了使其能够正常工作之外,我还想将其修改为新自定义 VJP API 的规范用户,以测试该 API 是否良好。

    • 在此过程中,我对 odeint 的实现进行了其他改进

      • 移除展平/反展平的样板代码

      • 利用 lax.scan 移除索引更新逻辑

      • 在简单摆锤基准测试上速度提升 20% 以上

  • 为自定义导数调用原始操作 custom_jvp_callcustom_vjp_call 在每个变换上添加了一个自定义绑定方法。它类似于 core.call_bind,只是我们不处理环境跟踪(env traces):那些只是错误。

  • 添加了 custom_lin 原始操作,当使用自定义 VJP 规则时,它会被暂存为线性 jaxpr 以进行转置。

    • 由于我们的反向模式自动微分分解为线性化、部分求值和转置,我们的自定义 VJP 规则在两个单独的步骤中处理:一个在线性化期间,一个在转置期间。

    • 线性化步骤,即 custom_vjp_call 的 JVP 规则,将 custom_lin 应用于切线值;custom_lin 携带用户自定义的反向传播函数,并且作为一个原始操作,它只有一个转置规则。

    • 此机制在 #636 中有更详细的描述。

  • 为防止