可JAX转换函数的自定义JVP/VJP规则#

这是一份设计文档,解释了jax.custom_jvpjax.custom_vjp设计和实现背后的一些思考。有关面向用户的文档,请参阅教程笔记本

在 JAX 中定义微分规则有两种方式:

  1. 使用jax.custom_jvpjax.custom_vjp为已JAX可转换的Python函数定义自定义微分规则;以及

  2. 定义新的core.Primitive实例及其所有转换规则,例如调用来自其他系统的函数,如求解器、模拟器或通用数值计算系统。

本文档仅讨论#1。

目录#

目标#

我们希望**用户**能够自定义其代码的正向和/或反向模式微分行为。这种自定义

  1. 应具有*清晰一致的语义*,说明其工作原理以及如何与其他JAX转换组合;并且

  2. 应*灵活*支持AutogradPyTorch中的用例和工作流程,包括涉及Python控制流的微分以及NaN调试的工作流程。

作为**JAX开发者**,我们希望编写库函数,例如logitexpit,这些函数是根据其他原语定义的,但为了微分的目的,它们具有原语般的行为,即我们希望为它们定义自定义微分规则,这些规则可能在数值上更稳定或性能更好。特别是,我们不希望必须为logitexpit等函数指定vmapjit规则。

作为一个延伸目标,我们希望JAX成为为寻求对fixed_pointodeint等高阶函数添加自定义微分规则的高级用户提供良好环境;本设计文档不会解决这个问题,但我们希望确保我们不会排除该问题的良好解决方案。

也就是说,我们的主要目标是

  1. 解决vmap-移除-custom-jvp语义问题(#1249),以及

  2. 允许在自定义VJP中使用Python,例如调试NaN(#1275)。

次要目标是:3. 清理和简化用户体验(符号零、kwargs等),4. 在用户可以轻松添加fixed_pointodeintroot等的世界中取得进展。

总而言之,我们希望关闭#116#1097#1249#1275#1366#1723#1670#1875#1938,并替换custom_transforms机制(来自#636#818及其他)。

非目标#

以下是我们**不**打算实现的目标

  1. custom_transforms机制旨在提供一种转换通用的机制来定制行为,原则上(尽管在实践中从未真正使用过)允许用户自定义任何转换的规则,同时以某种方式继承其他转换的“透明”行为。**我们现在只解决微分(JVP和VJP,分别)的定制问题。** 微分是唯一实际被请求的用例,并且通过专注于微分,我们可以降低复杂性并提高灵活性。要控制所有规则,可以直接编写一个原语。

  2. **我们不会将数学美学置于用户端的灵活性和清晰性以及实现端的简洁性之上。**特别是,尽管自定义VJP签名a -> (b, CT b --o CT a)在数学上令人满意,但如果由于返回类型中的闭包而难以在Python机制中实现,我们可以接受更明确处理残差的方法。

  3. **序列化支持**,即能够加载已暂存的序列化程序表示并进一步进行JAX转换而不是仅仅进行评估的形式,目前超出了这些自定义JVP/VJP转换规则的范围。序列化不仅对想要保存计算表示(并在加载后转换它)的研究人员有用,而且对未来考虑也可能有用,例如在Python外部实现jaxpr转换,或者将jaxprs作为MLIR方言。通过将此定义为本设计的非目标,我们对Python可调用对象的存储位置限制更少。

主要问题描述#

vmap-移除-custom-jvp语义问题#

vmap-移除-custom-jvp语义问题在于vmap无法正确地与带有custom_transforms规则的函数的微分组合

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

最后一行grad-of-vmap的结果出乎意料!一般来说,应用vmap,或者任何非微分转换,都会导致移除自定义微分规则。(应用jvp会在定义了自定义VJP规则时导致失败。)

问题在于转换类似于重写,而vmap转换有效地重写了函数,使其不再调用新引入的、具有自定义规则的原语(因此grad也就不会产生自定义规则的结果)。更详细地说,custom_transforms机制的设置使得评估f(x)会应用该函数

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

其中f_primitive是一个新原语(为每个custom_transforms函数,实际上是为每次函数调用而引入),自定义VJP规则与其关联。当我们评估grad(f)(x)时,微分机制会遇到f_primitive并使用自定义规则对其进行处理。

然而,由于f_primitivevmap是*透明的*,即vmap对其定义进行操作(实际上是通过内联),所以函数vmap(f)实际上是

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

换句话说,vmap根据其底层原语及其转换规则重写函数,完全移除了f_primitive

更普遍地说,**因为vmap(f)的语义是根据对f的调用定义的,所以移除自定义导数规则在语义上是不一致的**。也就是说,既然我们定义了

vmap(f)(xs) == np.stack([f(x) for x in xs])

我们必须有

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))

然而,当f定义了自定义导数规则时,该属性并未被遵守,因为自定义导数规则在右侧版本中被使用,而在左侧版本中则没有。

这个问题并非vmap特有;它适用于所有转换,这些转换的功能f的语义是根据对函数f的调用定义的,而不是将其重写为另一个函数。mask转换也属于此类。微分转换和假设的所有一元函数变为余弦的转换不属于此类。

(额外的自定义规则(如自定义vmap规则)之间的交互可能变得更加复杂,这表明custom_transforms的问题框架过于宽泛。)

Python灵活性问题#

在JAX中,如同在AutogradPyTorch中一样(但与TF1不同),Python函数的微分是在函数执行和追踪时进行的。这种行为让用户感到满意,原因有几点。

首先也是最重要的是,它支持基于pdb的工作流程,例如用于检查数值或捕获NaN。也就是说,用户可以使用标准Python调试器和其他Python原生工具来调试其代码,甚至能够检查运行时值以了解示例的数值行为并捕获根本性的运行时错误,如NaN。事实上,就在处理与本设计相关的PR时,特别是在odeint原语上,我多次使用运行时值检查来调试问题,这增加了我对此是Python中关键用户工作流程的信心。一个特别方便的技巧,我在JAX和Autograd中多次使用过,是能够在自定义VJP规则中插入调试器断点,以便在反向传播的特定点进入调试器。

其次,它允许对Python原生控制流进行微分。 我们不确定这在最终的软件产品中实际使用频率有多高,但当用户首次接触JAX或Autograd时,他们通常会被这种自由所震撼。我们将它列在JAX和Autograd的README、幻灯片和演示的顶部是有原因的。放弃这种能力将是Autograd的一个倒退。我们希望JAX拥有最好的自动微分能力。

然而,custom_transforms机制不提供这种Python支持的灵活性。也就是说,由于它通过预先从Python代码中形成用户函数和自定义微分规则的jaxpr来实现,像这样的代码会导致抽象值追踪错误

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!

解决方案思路#

主要思想是,**dougalm@已经用core.call解决了这些问题**。也就是说,我们可以将为用户函数指定自定义JVP规则的任务,视为一个新的Python级调用原语(不添加到jaxpr语言中;见下文)。这个新的调用原语与core.call一样,关联一个用户Python函数,但额外还有一个表示JVP规则的第二个Python可调用对象。我们将这个新的调用原语称为custom_jvp_call

vmap这样的转换与custom_jvp_call的交互方式与core.call类似:它们实际上是直接通过它,并应用于底层的Python可调用对象。示意性地,为了方便起见,以原语的柯里化版本表示,类似于vmap通过应用于要调用的函数来与core.call交互

vmap(call(f)) == call(vmap(f))

对于新的原语custom_jvp_call,我们只需将vmap应用于它所包含的两个函数

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

这种行为意味着我们已经解决了vmap-移除-custom-jvp语义问题

jvp转换的交互方式正如人们所期望:它只调用f_jvp

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

由于custom_jvp_call的行为类似于core.call(而不是xla.xla_call),因为它不提高其输入的抽象级别(因为它没有延迟或暂存任何东西),这意味着我们已经解决了Python灵活性问题:对用户Python函数没有限制(超出jvpvjp所需的通常函数式编程限制)。

那么评估和编译呢?这两种方式都代表了“退出”JAX系统,即在这些步骤之后不能应用额外的转换。因此,它们的规则是微不足道的

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

换句话说,如果JVP规则尚未将custom_jvp_call(f, f_jvp)重写为f_jvp,那么当我们通过eval进行评估或通过jit暂存到XLA时,将永远不会应用微分,因此我们只需忽略f_jvp并像core.call一样行事。然而,由于接下来讨论的复杂性,custom_jvp_call的部分评估规则必须更加复杂,因为部分评估不仅仅用于通过jit暂存到XLA。

唯一剩下的麻烦与“初始样式”jaxpr形成原语(如lax.scan)及其转换规则有关。这些代表了与编译不同的另一种“暂存到jaxpr”的方式,因为我们可以对暂存的jaxpr执行额外的转换。也就是说,当lax.scan形成jaxpr时,它不会退出转换系统,因为当我们对lax.scan应用jvp或vmap时,我们需要将其应用于jaxpr表示的函数。

另一种表述这个麻烦的方式是,初始样式原语(如lax.scan)依赖于将jaxpr往返转换为Python可调用对象并返回的能力,同时保留语义。这意味着也必须保留自定义微分规则的语义。

解决方案是使用一点动态作用域:当我们为初始样式原语(如lax_control_flow.py中的那些)暂存到jaxpr时,我们在全局追踪状态上设置一个位。当该位被设置时,我们不使用最终样式custom_jvp_call原语,而是使用初始样式custom_jvp_call_jaxpr原语,并预先将函数ff_jvp追踪到jaxpr,以方便初始样式处理。custom_jvp_call_jaxpr原语在其他方面与最终样式版本相似。

(脚注:尽管从道德上讲,我们在绑定custom_jvp_call_jaxpr之前为ff_jvp都形成了jaxpr,但我们需要延迟f_jvp的jaxpr形成,因为它可能会调用自定义JVP函数,因此急切处理将导致无限递归。我们在一个thunk中延迟了该jaxpr的形成。)

如果我们放弃了Python灵活性问题,我们就可以只使用custom_jvp_call_jaxpr,而不需要独立的Python级原语custom_jvp_call

API#

一个a -> b函数的自定义JVP通过一个(a, Ta) -> (b, T b)函数指定

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(有趣的自动微分旁白:为了使规则适用于高阶微分,必须在f_jvp的函数体内调用f;这排除了f的内部和切线计算之间某些类型的工作共享。)

一个a -> b函数的自定义VJP由一个a -> (b, c)前向传播函数和一个(c, CT b) -> CTa反向传播函数指定

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

签名a -> (b, CT b --o CT a)在美学上更令人愉悦,但支持它会使实现更复杂,并且可能需要牺牲表达能力的需求。基本原因是Python可调用对象是不透明的(除非我们急切地将它们追踪到jaxpr,这会施加表达限制),在这种情况下,我们可能会返回一个在其闭包内部包含vmap追踪器的可调用对象,我们需要在正向传播过程中了解这些信息。

我们可以添加便捷包装器,例如一次为一个参数定义JVP规则(就像我们在内部为原语所做的那样)。但是由于这个提议已经足够复杂,我决定不增加便捷层;目前让我们保持简单。

API还有一些其他附加功能

  • 输入和输出类型abc可以是jaxtypes的任意pytrees。

  • 当参数可以通过inspect模块解析到位置时,支持通过名称(关键字参数)传递参数。这有点像是对Python 3改进的程序化检查参数签名能力的实验。我认为它是可靠但不完整的,这已经足够好了。(另请参阅#2069。)

  • 参数可以使用nondiff_argnums标记为不可微分,与jitstatic_argnums一样,这些参数不必是JAX类型。我们需要设定一个约定,这些参数如何传递给规则。对于类型签名为(d, a) -> b的原始函数,其中d表示不可微分类型,JVP规则的签名为(a, T a, d) -> T b,VJP规则的反向组件签名为(d, c, CT b) -> CT a。也就是说,不可微分参数在自定义JVP规则中按顺序在primalstangents之后传递,在自定义VJP规则的反向函数中在残差之前按顺序传递。

实现说明#

  • 更新了jax.experimental.odeint

    • 由于odeint是一个相当复杂的自定义VJP规则使用者,除了更新它使其能够正常工作之外,我还想对其进行修改,使其成为新自定义VJP API的规范用户,以此来测试API是否良好。

    • 在此过程中,我还对odeint实现进行了其他改进

      • 删除 raveling/unraveling 样板代码

      • 利用lax.scan来删除索引更新逻辑

      • 在简单摆锤基准测试中提速20%以上

  • 在自定义导数调用原语custom_jvp_callcustom_vjp_call的每个转换上添加了自定义绑定方法。它类似于core.call_bind,除了我们不处理环境追踪:那些只是错误。

  • 添加了custom_lin原语,当使用自定义VJP规则时,该原语会被暂存到线性jaxpr中进行转置。

    • 由于我们的反向模式自动微分被分解为线性化、部分求值和转置,因此我们的自定义 VJP 规则分两步处理:一步在线性化期间,一步在转置期间。

    • 线性化步骤,即custom_vjp_call的JVP规则,将custom_lin应用于切线值;custom_lin携带着用户的自定义反向传播函数,作为一个原语,它只具有转置规则。

    • 此机制在#636中有更详细的描述。

  • 为了防止