JAX 可转换函数的自定义 JVP/VJP 规则#

这是一份设计文档,解释了 jax.custom_jvpjax.custom_vjp 的设计和实现背后的一些思考。有关面向用户的文档,请参阅教程笔记本

在 JAX 中定义微分规则有两种方法

  1. 使用 jax.custom_jvpjax.custom_vjp 为已可 JAX 转换的 Python 函数定义自定义微分规则;以及

  2. 定义新的 core.Primitive 实例以及它们的所有转换规则,例如调用来自其他系统(如求解器、模拟器或通用数值计算系统)的函数。

本文档仅关于 #1。

目录#

目标#

我们希望用户自定义其代码的前向和/或反向模式微分行为。这种自定义

  1. 在工作方式以及如何与其他 JAX 转换组合方面,应具有清晰且一致的语义;并且

  2. 在支持用例和工作流程方面应灵活,如 AutogradPyTorch 中那样,包括涉及 Python 控制流微分和 NaN 调试的工作流程。

作为 JAX 开发者,我们希望编写库函数,例如 logitexpit,这些函数根据其他原语定义,但出于微分的目的,具有类似原语的行为,因为我们希望为它们定义自定义微分规则,这些规则可能在数值上更稳定或性能更好。特别是,我们不想为 logitexpit 等函数指定 vmapjit 规则。

作为一个延伸目标,我们希望使 JAX 成为高级用户添加高阶函数(如 fixed_pointodeint 等)自定义微分规则的绝佳环境;本文档不会解决该问题,但我们希望确信我们不会排除该问题的良好解决方案。

也就是说,我们的主要目标是

  1. 解决 vmap-移除-自定义-jvp 语义问题 (#1249),以及

  2. 允许在自定义 VJP 中使用 Python,例如调试 NaN (#1275)。

次要目标是 3. 清理和简化用户体验(符号零、kwargs 等)4. 朝着用户可以轻松添加 fixed_pointodeintroot 等的世界迈进。

总的来说,我们希望关闭 #116#1097#1249#1275#1366#1723#1670#1875#1938,并替换 custom_transforms 机制(来自 #636#818 和其他)。

非目标#

以下是我们打算实现的目标

  1. custom_transforms 机制旨在提供一种转换通用的机制来自定义行为,原则上(尽管实际上从未真正使用过)允许用户自定义任何转换的规则,同时以某种方式继承其他转换的“透明”行为。我们改为仅解决微分(JVP 和 VJP,分别)的自定义问题。微分是唯一实际请求的情况,通过专门针对微分,我们可以降低复杂性并提高灵活性。要控制所有规则,只需编写一个原语即可。

  2. 我们不会优先考虑数学美感,而是用户端的灵活性和清晰度,以及实现端的简单性。特别是,虽然自定义 VJP 签名 a -> (b, CT b --o CT a) 在数学上令人愉悦,但如果由于返回类型中的闭包而难以在 Python 机制中实现,那么我们可以更明确地处理残差。

  3. 序列化支持,即 staged-out 序列化程序表示可以被加载并进一步进行 JAX 转换,而不仅仅是评估的形式,目前不在这些自定义 JVP/VJP 转换规则的范围内。序列化不仅对于想要保存其计算的某种表示形式(并在加载后对其进行转换)的研究人员可能有用,而且对于未来的考虑(例如在 Python 外部实现 jaxpr 转换,或将 jaxpr 作为 MLIR 方言)也可能有用。通过将此定义为本设计的非目标,我们在可以存放 Python 可调用对象的位置上的约束更少。

主要问题描述#

vmap-移除-自定义-jvp 语义问题#

vmap-移除-自定义-jvp 语义问题是 vmap 不能与具有 custom_transforms 规则的函数的微分正确组合

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

最后一个 grad-of-vmap 行具有意外结果!一般来说,应用 vmap,或者任何非微分转换,都具有移除自定义微分规则的效果。(当定义了自定义 VJP 规则时,应用 jvp 会导致失败。)

问题存在的原因是转换就像重写,vmap 转换有效地重写了函数,使其不再调用为其存在自定义规则的新引入的原语(因此 grad 然后不会产生自定义规则的结果)。更详细地说,custom_transforms 机制设置使得评估 f(x) 应用函数

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

其中 f_primitive 是一个新的原语(为每个 custom_transforms 函数引入,实际上是为函数的每次调用引入),自定义 VJP 规则与该原语关联。当我们评估 grad(f)(x) 时,微分机制遇到 f_primitive 并使用自定义规则处理它。

然而,由于 f_primitivevmap透明的,从某种意义上说,vmapf_primitive 的定义进行操作(有效地通过内联),函数 vmap(f) 实际上是

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

换句话说,vmap 根据其底层原语及其转换规则重写函数,完全移除了 f_primitive

更一般地说,由于 vmap(f) 的语义是根据对 f 的调用定义的,因此移除自定义导数规则在语义上是不一致的。也就是说,由于我们定义

vmap(f)(xs) == np.stack([f(x) for x in xs])

我们必须有

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))

然而,当 f 定义了自定义导数规则时,不会观察到此属性,因为自定义导数规则在右侧版本中使用,但在左侧版本中未使用。

这个问题并非特定于 vmap;它适用于所有转换,对于这些转换,转换函数 f 的语义是根据对函数 f 的调用定义的,而不是将其重写为另一个函数。mask 转换也属于此类。微分转换和假设的所有一元函数都变成余弦变换不在此类中。

(额外的自定义规则(如自定义 vmap 规则)之间的交互可能会变得更加复杂,这表明 custom_transforms 的问题框架过于宽泛。)

Python 灵活性问题#

在 JAX 中,如 AutogradPyTorch 中一样,但与 TF1 不同,Python 函数的微分是在函数执行和跟踪时执行的。这种行为让用户感到高兴,原因有几个。

首先也是最重要的是,它启用了基于 pdb 的工作流程,例如用于检查数值或捕获 NaN。 也就是说,用户可以采用标准的 Python 调试器和其他 Python 原生工具来调试他们的代码,甚至能够检查运行时值以了解示例的数值行为并捕获根本的运行时错误(如 NaN)。事实上,就在处理与此设计相对应的 PR 时,尤其是在 odeint 原语上,我多次使用运行时值检查来调试问题,这增加了我对这确实是 Python 中的关键用户工作流程的信心。一个特别方便的技巧,我在 JAX 和 Autograd 中多次使用过,是在自定义 VJP 规则中插入调试器断点,以便在反向传播的特定点进入调试器。

其次,它允许 Python 原生控制流的微分。 我们不确定这在最终的软件工件中使用的频率有多高,但当用户首次接触 JAX 或 Autograd 时,他们通常会对这种自由度印象深刻。我们有理由将其包含在我们的 JAX 和 Autograd README、幻灯片和演示文稿的顶部。放弃这种能力将是从 Autograd 向后退一步。我们希望 JAX 具有最佳的自动微分功能。

然而,custom_transforms 机制不提供这种 Python 支持的灵活性。也就是说,因为它是在用户函数和自定义微分规则的 Python 代码中预先形成 jaxpr 的基础上实现的,所以像这样的代码会导致抽象值跟踪错误

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!

解决方案思路#

主要思想是 dougalm@ 已经使用 core.call 解决了这些问题。也就是说,我们可以将为用户函数指定自定义 JVP 规则的任务构建为新的 Python 级别调用原语(不会添加到 jaxpr 语言中;见下文)。这个新的调用原语有一个与其关联的用户 Python 函数,就像 core.call 一样,但此外还有一个表示 JVP 规则的第二个 Python 可调用对象。让我们将这个新的调用原语称为 custom_jvp_call

vmap 这样的转换与 custom_jvp_call 的交互方式与 core.call 相同:它们有效地直接传递它并应用于底层 Python 可调用对象。示意性地,为了方便起见,用原语的柯里化版本编写,类似于 vmap 通过应用于要调用的函数与 core.call 交互的方式

vmap(call(f)) == call(vmap(f))

对于新的原语 custom_jvp_call,我们只需将 vmap 应用于它包含的两个函数

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

这种行为意味着我们已经解决了 vmap-移除-自定义-jvp 语义问题

jvp 转换的交互方式正如人们可能期望的那样:它只是调用 f_jvp

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

因为 custom_jvp_call 的行为类似于 core.call(而不是像 xla.xla_call),因为它没有提高其输入的抽象级别(因为它没有延迟任何内容或 staged-out 任何内容),这意味着我们已经解决了 Python 灵活性问题:对用户 Python 函数没有任何约束(高于 jvpvjp 所需的常用函数式编程约束)。

评估和编译呢?这是“退出” JAX 系统的两种方法,从某种意义上说,在这些步骤之后不能应用其他转换。因此,它们的规则是微不足道的

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

换句话说,如果 JVP 规则尚未将 custom_jvp_call(f, f_jvp) 重写为 f_jvp,当我们通过 eval 进行评估或通过 jit staged-out 到 XLA 时,永远不会应用微分,因此我们只需忽略 f_jvp,其行为就像 core.call。但是,由于接下来讨论的曲折,custom_jvp_call 的部分评估规则必须稍微复杂一些,因为部分评估不仅仅用于通过 jit staged-out 到 XLA。

唯一剩下的曲折与“初始样式”jaxpr 形成原语(如 lax.scan)及其转换规则有关。这些代表了一种与编译不同的“staged-out 到 jaxpr”类型,因为我们可以对 staged-out jaxpr 执行额外的转换。也就是说,当 lax.scan 形成 jaxpr 时,它不会退出转换系统,因为当我们对 lax.scan 应用 jvp 或 vmap 时,我们需要将其应用于 jaxpr 表示的函数。

另一种描述曲折的方式是,初始样式原语(如 lax.scan)依赖于往返于 jaxpr 和 Python 可调用对象的能力,同时保持语义。这也必须意味着保持自定义微分规则语义。

解决方案是使用一点动态作用域:当我们为 initial-style 原始函数(例如 lax_control_flow.py 中的那些函数)分段输出到 jaxpr 时,我们在全局跟踪状态上设置一个位。当该位被设置时,我们不使用 final-style 的 custom_jvp_call 原始函数,而是使用 initial-style 的 custom_jvp_call_jaxpr 原始函数,并预先将函数 ff_jvp 跟踪到 jaxpr,以便更轻松地进行 initial-style 处理。custom_jvp_call_jaxpr 原始函数在其他方面与 final-style 版本类似。

(脚注:虽然在表面上我们在绑定 custom_jvp_call_jaxpr 之前为 ff_jvp 都形成了 jaxpr,但我们需要延迟 f_jvp 的 jaxpr 的形成,因为它可能会调用 custom-JVP 函数,因此 eager 处理会导致无限递归。我们以 thunk 的形式延迟 jaxpr 的形成。)

如果我们放弃Python 灵活性问题,我们就可以只使用 custom_jvp_call_jaxpr,而不需要单独的 Python 级别原始函数 custom_jvp_call

API#

函数 a -> b 的自定义 JVP 由函数 (a, Ta) -> (b, T b) 指定

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(有趣的自动微分旁注:为了使规则适用于高阶微分,必须在 f_jvp 的主体中调用 f;这排除了一些在 f 的内部和切线计算之间进行工作共享的方法。)

函数 a -> b 的自定义 VJP 由 a -> (b, c) 前向传递函数与 (c, CT b) -> CT a 后向传递函数配对指定

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

签名 a -> (b, CT b --o CT a) 在美学上更令人愉悦,但支持它会使实现更加复杂,并可能需要牺牲可表达性方面的期望。基本原因是 Python 可调用对象是不透明的(除非我们急切地将它们跟踪到 jaxpr,这会带来表达性约束),在这种情况下,我们可能会返回一个可调用对象,其闭包内有 vmap 跟踪器,我们需要在前向传递期间了解这些跟踪器。

我们可以添加便利的包装器,例如,一次为一个参数定义 JVP 规则(就像我们在内部为原始函数所做的那样)。但由于这个提案已经够复杂了,所以我决定反对便利层;现在让我们保持最小化。

API 还有一些其他的附加功能

  • 输入和输出类型 abc 可以是 jaxtypes 的任意 pytree。

  • 当可以使用 inspect 模块将命名参数(关键字参数)解析为位置参数时,支持命名参数(关键字参数)。这是对 Python 3 改进的以编程方式检查参数签名能力的一个实验。我相信它是合理的,但并不完整,这是一个不错的位置。(另请参阅 #2069。)

  • 可以使用 nondiff_argnums 将参数标记为不可微分,并且与 jitstatic_argnums 一样,这些参数不必是 JAX 类型。我们需要为如何将这些参数传递给规则设置一个约定。对于具有类型签名 (d, a) -> b 的原始函数,其中 d 表示不可微分类型,JVP 规则的签名是 (a, T a, d) -> T b,而 VJP 规则的反向组件签名是 (d, c, CT b) -> CT a。也就是说,对于自定义 JVP 规则,不可微分参数在 primalstangents 之后按顺序传递,并且在自定义 VJP 规则的反向函数中,在残差之前按顺序传递。

实现说明#

  • 更新了 jax.experimental.odeint

    • 由于 odeint 是自定义 VJP 规则的一个相当复杂的用户,除了更新它以使其完全工作之外,我还想修改它,使其成为新的自定义 VJP API 的规范用户,以此来测试该 API 是否良好。

    • 在此过程中,我对 odeint 的实现进行了其他改进

      • 删除 raveling/unraveling 样板代码

      • 利用 lax.scan 删除索引更新逻辑

      • 在简单摆基准测试中加速 20+%

  • 为自定义导数调用原始函数 custom_jvp_callcustom_vjp_call 在每个变换上添加了自定义绑定方法。它类似于 core.call_bind,但我们不处理 env 跟踪:那些只是错误。

  • 添加了 custom_lin 原始函数,它被分段输出到线性 jaxpr 中,以便在使用自定义 VJP 规则时进行转置。

    • 由于我们的反向模式自动微分分解为线性化、部分求值和转置,因此我们的自定义 VJP 规则分两个独立的步骤进行处理:一个在线性化期间,另一个在转置期间。

    • 线性化步骤,即 custom_vjp_call 的 JVP 规则,将 custom_lin 应用于切线值;custom_lin 带有用户自定义的反向传递函数,并且作为原始函数,它仅具有转置规则。

    • 此机制在 #636 中有更详细的描述。

  • 防止