自定义导数规则#

Open in Colab Open in Kaggle

在 JAX 中定义微分规则有两种方式:

  1. 使用 jax.custom_jvpjax.custom_vjp 为已经可以进行 JAX 转换的 Python 函数定义自定义微分规则;以及

  2. 定义新的 core.Primitive 实例及其所有转换规则,例如调用其他系统(如求解器、模拟器或通用数值计算系统)的函数。

本教程关注 #1。要了解 #2,请参阅 关于添加原语的教程

有关 JAX 自动微分 API 的介绍,请参阅 自动微分手册。本教程假定您熟悉 jax.jvpjax.grad,以及 JVP 和 VJP 的数学含义。

概述#

使用 jax.custom_jvp 自定义 JVP#

import jax.numpy as jnp
from jax import custom_jvp

@custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out
from jax import jvp, grad

print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
# Equivalent alternative using the defjvps convenience wrapper

@custom_jvp
def f(x, y):
  return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405

使用 jax.custom_vjp 自定义 VJP#

from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  # Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res # Gets residuals computed in f_fwd
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405

示例问题#

为了了解 jax.custom_jvpjax.custom_vjp 要解决的问题,让我们来看几个例子。对 jax.custom_jvpjax.custom_vjp API 的更全面介绍将在下一节中进行。

数值稳定性#

jax.custom_jvp 的一个应用是提高微分的数值稳定性。

假设我们要编写一个名为 log1pexp 的函数,它计算 \(x \mapsto \log ( 1 + e^x )\)。我们可以使用 jax.numpy 来实现它

def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

log1pexp(3.)
Array(3.0485873, dtype=float32, weak_type=True)

由于它是使用 jax.numpy 编写的,因此它可以进行 JAX 转换

from jax import jit, grad, vmap

print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5       0.7310586 0.8807971]

但这里隐藏着一个数值稳定性问题

print(grad(log1pexp)(100.))
nan

这似乎不对!毕竟, \(x \mapsto \log (1 + e^x)\) 的导数是 \(x \mapsto \frac{e^x}{1 + e^x}\),所以对于很大的 \(x\) 值,我们期望其值约为 1。

我们可以通过查看梯度计算的 jaxpr 来更深入地了解正在发生的事情

from jax import make_jaxpr

make_jaxpr(grad(log1pexp))(100.)
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0:f32[] b
    _:f32[] = log c
    d:f32[] = div 1.0:f32[] c
    e:f32[] = mul d b
  in (e,) }

逐步分析 jaxpr 的计算过程,我们可以看到最后一行将涉及乘以浮点数计算会四舍五入为 0 和 \(\infty\) 的值,这永远不是个好主意。也就是说,对于很大的 x,我们实际上是在计算 lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x),这实际上变成了 0. * jnp.inf

与其产生如此大和小的值,期望浮点数无法始终提供的抵消,我们宁愿将导数函数表示为一个更具数值稳定性的程序。特别是,我们可以编写一个程序,它更接近于计算等效的数学表达式 \(1 - \frac{1}{1 + e^x}\),而没有抵消。

这个问题很有趣,因为尽管我们对 log1pexp 的定义已经可以进行 JAX 微分(并进行 jitvmap 等转换),但我们对应用标准自动微分规则到构成 log1pexp 的原语并组合结果的效果不满意。相反,我们希望指定整个函数 log1pexp 应该如何被微分,作为一个整体,从而更好地安排这些指数。

这是自定义导数规则应用于已可进行 JAX 转换的 Python 函数的一种应用:指定复合函数应如何微分,同时仍使用其原始 Python 定义进行其他转换(如 jitvmap 等)。

这是使用 jax.custom_jvp 的解决方案

from jax import custom_jvp

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
  return ans, ans_dot
print(grad(log1pexp)(100.))
1.0
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5       0.7310586 0.8807971]

这是一个 defjvps 便利包装器,用于表达相同的内容

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
1.0
3.0485873
0.95257413
[0.5       0.7310586 0.8807971]

强制执行微分约定#

一个相关的应用是强制执行微分约定,可能是在边界处。

考虑函数 \(f : \mathbb{R}_+ \to \mathbb{R}_+\),其中 \(f(x) = \frac{x}{1 + \sqrt{x}}\),我们取 \(\mathbb{R}_+ = [0, \infty)\)。我们可以像这样实现 \(f\) 的程序

def f(x):
  return x / (1 + jnp.sqrt(x))

作为 \(\mathbb{R}\)(整个实数轴)上的数学函数,\(f\) 在零处不可微分(因为定义导数的极限从左侧不存在)。相应地,自动微分会产生一个 nan

print(grad(f)(0.))
nan

但从数学上讲,如果我们把 \(f\) 看作是 \(\mathbb{R}_+\) 上的函数,那么它在 0 处是可微的 [Rudin 的《数学分析原理》定义 5.1,或 Tao 的《分析 I》第 3 版定义 10.1.1 和示例 10.1.6]。或者,我们可以说,按照惯例,我们希望考虑右侧的定向导数。因此,Python 函数 grad(f)0.0 处返回一个有意义的值,即 1.0。默认情况下,JAX 的微分机制假定所有函数都在 \(\mathbb{R}\) 上定义,因此不会在此处产生 1.0

我们可以使用自定义 JVP 规则!特别是,我们可以根据 \(\mathbb{R}_+\) 上的导数函数 \(x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}\) 来定义 JVP 规则,

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
  return ans, ans_dot
print(grad(f)(0.))
1.0

这是便利包装器版本

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))

f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
print(grad(f)(0.))
1.0

梯度裁剪#

虽然在某些情况下我们希望表达一个数学微分计算,但在其他情况下,我们甚至可能希望采取一种脱离数学的方法来调整自动微分执行的计算。一个典型的例子是后向模式梯度裁剪。

对于梯度裁剪,我们可以使用 jnp.clip 结合 jax.custom_vjp 仅后向模式规则

from functools import partial
from jax import custom_vjp

@custom_vjp
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save bounds as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # use None to indicate zero cotangents for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
import matplotlib.pyplot as plt
from jax import vmap

t = jnp.linspace(0, 10, 1000)

plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t))
[<matplotlib.lines.Line2D at 0x7a99c512a8a0>]
../_images/085d4cd583c83b63eb11f23606040a988b6014181c2f6e4d4d2de8d2f6ebfa67.png
def clip_sin(x):
  x = clip_gradient(-0.75, 0.75, x)
  return jnp.sin(x)

plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))
[<matplotlib.lines.Line2D at 0x7a99c4ef49e0>]
../_images/4e5b2aed7165658c9700d6bb3015332c98918ec8dfc7733c6fce7776df4955c6.png

Python 调试#

另一个受开发工作流程而非数值驱动的应用是,在后向模式自动微分的后向传播中设置一个 pdb 调试器断点。

当尝试追踪 nan 运行时错误的原因,或者只是仔细检查传播的共切向量(梯度)值时,在后向传播中与原始计算的特定点相对应的点插入调试器会很有用。您可以使用 jax.custom_vjp 来实现这一点。

我们将在下一节中推迟一个例子。

迭代实现的隐函数微分#

这个例子涉及大量的数学细节!

jax.custom_vjp 的另一个应用是后向模式微分那些可以进行 JAX 转换(通过 jitvmap 等)但由于某种原因无法高效进行 JAX 微分的函数,原因可能是它们涉及 lax.while_loop。(无法生成高效计算 XLA HLO 循环的后向模式导数的 XLA HLO 程序,因为这将需要一个内存使用无界的程序,而这在 XLA HLO 中无法表示,至少在没有通过 infeed/outfeed 进行副作用交互的情况下是如此。)

例如,考虑这个 fixed_point 过程,它通过在 while_loop 中迭代应用一个函数来计算不动点

from jax.lax import while_loop

def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star

这是一个数值求解方程 \(x = f(a, x)\) 的迭代过程,通过迭代 \(x_{t+1} = f(a, x_t)\) 直到 \(x_{t+1}\) 非常接近 \(x_t\)。结果 \(x^*\) 取决于参数 \(a\),因此我们可以认为存在一个函数 \(a \mapsto x^*(a)\),它由方程 \(x = f(a, x)\) 隐式定义。

我们可以使用 fixed_point 来运行迭代过程直至收敛,例如运行牛顿法来计算平方根,而只执行加法、乘法和除法

def newton_sqrt(a):
  update = lambda a, x: 0.5 * (x + a / x)
  return fixed_point(update, a, a)
print(newton_sqrt(2.))
1.4142135

我们也可以 vmapjit 这个函数

print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
[1.        1.4142135 1.7320509 2.       ]

由于 while_loop,我们无法应用后向模式自动微分,但事实证明我们无论如何也不想这样做:与其对 fixed_point 及其所有迭代的实现进行微分,不如利用数学结构来实现一个内存效率更高(在这种情况下,FLOP 效率也更高!)的方法。我们可以改用隐函数定理 [Bertsekas 的《非线性规划》第二版附录 A.25],该定理(在某些条件下)保证了我们即将使用的数学对象的存在。本质上,我们在解处进行线性化,并迭代求解这些线性方程来计算我们想要的导数。

再次考虑方程 \(x = f(a, x)\) 和函数 \(x^*\)。我们想计算向量-雅可比乘积,如 \(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)\)

至少在我们要微分的点 \(a_0\) 周围的一个开邻域内,假设方程 \(x^*(a) = f(a, x^*(a))\) 对所有 \(a\) 都成立。由于两侧作为 \(a\) 的函数是相等的,它们的导数也必须相等,所以让我们对两边进行微分

\(\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)\).

\(A = \partial_1 f(a_0, x^*(a_0))\)\(B = \partial_0 f(a_0, x^*(a_0))\),我们可以更简单地写出我们想要的量:

\(\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)\),

或者,重新排列一下,

\(\qquad \partial x^*(a_0) = (I - A)^{-1} B\).

这意味着我们可以计算向量-雅可比乘积,例如

\(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B\),

其中 \(w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}\),或者等价地 \(w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A\),或者等价地 \(w^\mathsf{T}\) 是映射 \(u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A\) 的不动点。最后一个描述为我们提供了一种方法,可以通过调用 fixed_point 来编写 fixed_point 的 VJP!此外,在将 \(A\)\(B\) 展开后,我们可以看到我们只需要评估 f\((a_0, x^*(a_0))\) 处的 VJP。

总而言之

from jax import vjp

@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star

def fixed_point_fwd(f, a, x_init):
  x_star = fixed_point(f, a, x_init)
  return x_star, (a, x_star)

def fixed_point_rev(f, res, x_star_bar):
  a, x_star = res
  _, vjp_a = vjp(lambda a: f(a, x_star), a)
  a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
                             (a, x_star, x_star_bar),
                             x_star_bar))
  return a_bar, jnp.zeros_like(x_star)

def rev_iter(f, packed, u):
  a, x_star, x_star_bar = packed
  _, vjp_x = vjp(lambda x: f(a, x), x_star)
  return x_star_bar + vjp_x(u)[0]

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
print(newton_sqrt(2.))
1.4142135
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))
0.35355338
-0.088388346

我们可以通过对 jnp.sqrt 进行微分来检查我们的答案,它使用了完全不同的实现

print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.))
0.35355338
-0.08838835

这种方法的一个限制是参数 f 不能闭合任何涉及微分的值。也就是说,您可能会注意到,我们在 fixed_point 的参数列表中保留了显式参数 a。对于这种情况,请考虑使用低级原语 lax.custom_root,它允许对闭合变量进行微分,并带有自定义根查找函数。

基本使用 jax.custom_jvpjax.custom_vjp API#

使用 jax.custom_jvp 定义前向模式(以及间接的,后向模式)规则#

这是一个使用 jax.custom_jvp 的典型基本示例,其中注释使用了 类似 Haskell 的类型签名

from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
  return jnp.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t

f.defjvp(f_jvp)
<function __main__.f_jvp(primals, tangents)>
from jax import jvp

print(f(3.))

y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)
0.14112
0.14112
-0.9899925

简单来说,我们从一个原始函数 f 开始,该函数接受类型为 a 的输入并产生类型为 b 的输出。我们将其关联一个 JVP 规则函数 f_jvp,该函数接受代表类型为 a 的原始输入的参数对,以及代表类型为 T a 的相应切向量输入,并产生代表类型为 b 的原始输出和类型为 T b 的切向量输出的参数对。切向量输出应该是切向量输入的线性函数。

您也可以将 f.defjvp 用作装饰器,例如

@custom_jvp
def f(x):
  ...

@f.defjvp
def f_jvp(primals, tangents):
  ...

即使我们只定义了一个 JVP 规则而没有定义 VJP 规则,我们也可以对 f 使用前向模式和后向模式微分。JAX 将自动转置我们自定义 JVP 规则中的切向量上的线性计算,从而以与我们手动编写规则一样高效的方式计算 VJP。

from jax import grad

print(grad(f)(3.))
print(grad(grad(f))(3.))
-0.9899925
-0.14112

为了使自动转置能够正常工作,JVP 规则的输出切向量必须是输入切向量的线性函数。否则将引发转置错误。

多个参数的工作方式如下

@custom_jvp
def f(x, y):
  return x ** 2 * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
  return primal_out, tangent_out
print(grad(f)(2., 3.))
12.0

defjvps 便利包装器允许我们分别定义每个参数的 JVP,然后将结果分开计算再求和

@custom_jvp
def f(x):
  return jnp.sin(x)

f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
print(grad(f)(3.))
-0.9899925

这是带有多个参数的 defjvps 示例

@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
4.0

作为一种简写,使用 defjvps,您可以传递一个 None 值来表示特定参数的 JVP 为零

@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          None)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
0.0

使用关键字参数调用 jax.custom_jvp 函数,或使用默认参数编写 jax.custom_jvp 函数定义,都是允许的,只要它们可以根据标准库 inspect.signature 机制检索到的函数签名明确映射到位置参数。

当您不进行微分时,函数 f 的调用方式与未被 jax.custom_jvp 装饰时相同

@custom_jvp
def f(x):
  print('called f!')  # a harmless side-effect
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  print('called f_jvp!')  # a harmless side-effect
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t
from jax import vmap, jit

print(f(3.))
called f!
0.14112
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.))
called f!
[0.         0.84147096 0.9092974 ]
called f!
0.14112

自定义 JVP 规则在微分时被调用,无论是前向还是后向

y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)
called f_jvp!
called f!
-0.9899925
print(grad(f)(3.))
called f_jvp!
called f!
-0.9899925

请注意,f_jvp 调用 f 来计算原始输出。在更高阶微分的上下文中,每次应用微分转换时,都将使用自定义 JVP 规则,当且仅当该规则调用原始 f 来计算原始输出时。(这代表了一种基本权衡,即我们无法在规则中使用 f 计算的中间值,同时 还能让该规则在所有阶的更高阶微分中都适用。)

grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
Array(-0.14112, dtype=float32, weak_type=True)

您可以使用 Python 控制流与 jax.custom_jvp

@custom_jvp
def f(x):
  if x > 0:
    return jnp.sin(x)
  else:
    return jnp.cos(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  if x > 0:
    return ans, 2 * x_dot
  else:
    return ans, 3 * x_dot
print(grad(f)(1.))
print(grad(f)(-1.))
2.0
3.0

使用 jax.custom_vjp 定义仅后向模式的自定义规则#

虽然 jax.custom_jvp 足以控制前向模式和(通过 JAX 的自动转置)后向模式微分行为,但在某些情况下,我们可能希望直接控制 VJP 规则,例如在上文介绍的最后两个示例问题中。我们可以通过 jax.custom_vjp 来做到这一点。

from jax import custom_vjp
import jax.numpy as jnp

# f :: a -> b
@custom_vjp
def f(x):
  return jnp.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), jnp.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)
from jax import grad

print(f(3.))
print(grad(f)(3.))
0.14112
-0.9899925

简单来说,我们再次从一个原始函数 f 开始,该函数接受类型为 a 的输入并产生类型为 b 的输出。我们将其关联两个函数 f_fwdf_bwd,它们分别描述了如何执行后向模式自动微分的前向和后向传递。

函数 f_fwd 描述了前向传递,不仅包括原始计算,还包括为后向传递保存的值。其输入签名与原始函数 f 相同,因为它接受类型为 a 的原始输入。但作为输出,它产生一个对,其中第一个元素是原始输出 b,第二个元素是任何“残差”数据,类型为 c,用于存储以供后向传递使用。(这个第二个输出类似于 PyTorch 的 save_for_backward 机制。)

函数 f_bwd 描述了后向传递。它接受两个输入,第一个是 f_fwd 生成的类型为 c 的残差数据,第二个是对应于原始函数输出的类型为 CT b 的输出共切向量。它产生一个类型为 CT a 的输出,代表对应于原始函数输入的共切向量。特别是,f_bwd 的输出必须是与原始函数参数数量相同的序列(例如元组)。

所以多个参数的工作方式如下

from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405

使用关键字参数调用 jax.custom_vjp 函数,或使用默认参数编写 jax.custom_vjp 函数定义,都是允许的,只要它们可以根据标准库 inspect.signature 机制检索到的函数签名明确映射到位置参数。

jax.custom_jvp 一样,如果未应用微分,则不会调用由 f_fwdf_bwd 组成的自定义 VJP 规则。如果函数被评估,或被 jitvmap 或其他非微分转换转换,则只调用 f

@custom_vjp
def f(x):
  print("called f!")
  return jnp.sin(x)

def f_fwd(x):
  print("called f_fwd!")
  return f(x), jnp.cos(x)

def f_bwd(cos_x, y_bar):
  print("called f_bwd!")
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd)
print(f(3.))
called f!
0.14112
print(grad(f)(3.))
called f_fwd!
called f!
called f_bwd!
-0.9899925
y, f_vjp = vjp(f, 3.)
print(y)
called f_fwd!
called f!
0.14112
print(f_vjp(1.))
called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),)

不能对 jax.custom_vjp 函数使用前向模式自动微分,否则会引发错误。

from jax import jvp

try:
  jvp(f, (3.,), (1.,))
except TypeError as e:
  print('ERROR! {}'.format(e))
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.

如果您想同时使用前向模式和后向模式,请改用 jax.custom_jvp

我们可以将 jax.custom_vjppdb 结合使用,在后向传播中插入调试器断点

import pdb

@custom_vjp
def debug(x):
  return x  # acts like identity

def debug_fwd(x):
  return x, x

def debug_bwd(x, g):
  pdb.set_trace()
  return g

debug.defvjp(debug_fwd, debug_bwd)
def foo(x):
  y = x ** 2
  y = debug(y)  # insert pdb in corresponding backward pass step
  return jnp.sin(y)
jax.grad(foo)(3.)

> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q

更多特性和细节#

处理 list / tuple / dict 容器(和其他 Pytrees)#

您应该期望像列表、元组、命名元组和字典这样的标准 Python 容器能够正常工作,以及这些容器的嵌套版本。通常,任何 Pytrees 都是允许的,只要它们的结构符合类型约束。

这是一个使用 jax.custom_jvp 的人为示例

from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])

@custom_jvp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

@f.defjvp
def f_jvp(primals, tangents):
  pt, = primals
  pt_dot, =  tangents
  ans = f(pt)
  ans_dot = {'a': 2 * pt.x * pt_dot.x,
             'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
  return ans, ans_dot

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0]
pt = Point(1., 2.)

print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True))

以及一个使用 jax.custom_vjp 的类似人为示例

@custom_vjp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

def f_fwd(pt):
  return f(pt), pt

def f_bwd(pt, g):
  a_bar, (b0_bar, b1_bar) = g['a'], g['b']
  x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
  y_bar = -jnp.sin(pt.y) * b1_bar
  return (Point(x_bar, y_bar),)

f.defvjp(f_fwd, f_bwd)

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0]
pt = Point(1., 2.)

print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True))

处理不可微分参数#

某些用例,如最后一个示例问题,需要将不可微分参数(如函数值参数)传递给具有自定义微分规则的函数,并且这些参数也要传递给规则本身。在 fixed_point 的情况下,函数参数 f 就是这样一个不可微分参数。类似的情况也出现在 jax.experimental.odeint 中。

jax.custom_jvp 配合 nondiff_argnums#

使用 jax.custom_jvp 的可选 nondiff_argnums 参数来指示此类参数。这是一个使用 jax.custom_jvp 的示例

from functools import partial

@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

@app.defjvp
def app_jvp(f, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(x), 2. * x_dot
print(app(lambda x: x ** 3, 3.))
27.0
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0

注意这里的陷阱:无论这些参数在参数列表中的哪个位置,它们都会被放在相应 JVP 规则签名的 *开头*。再举一个例子

@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
  return f(g((x)))

@app2.defjvp
def app2_jvp(f, g, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(g(x)), 3. * x_dot
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0

jax.custom_vjp 配合 nondiff_argnums#

jax.custom_vjp 也有一个类似的选项,同样,约定是不可微分参数将作为第一个参数传递给 _bwd 规则,无论它们出现在原始函数签名中的哪个位置。 _fwd 规则的签名保持不变——它与原始函数签名相同。这是一个示例

@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

def app_fwd(f, x):
  return f(x), x

def app_bwd(f, x, g):
  return (5 * g,)

app.defvjp(app_fwd, app_bwd)
print(app(lambda x: x ** 2, 4.))
16.0
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0

请参阅上面的 fixed_point 以获取另一个用法示例。

您无需为数值类型(例如整数 dtype)的数组值参数使用 nondiff_argnums。相反,nondiff_argnums 仅应用于不对应 JAX 类型(本质上不对应数组类型)的参数值,例如 Python 可调用对象或字符串。如果 JAX 检测到 nondiff_argnums 指定的参数包含 JAX Tracer,则会引发错误。上面的 clip_gradient 函数是一个不使用 nondiff_argnums 来处理整数 dtype 数组参数的好例子。