自动微分指南#

Open in Colab Open in Kaggle

JAX 拥有一个相当通用的自动微分系统。在本笔记本中,我们将从基础开始,深入探讨许多可供您在自己的工作中选择的巧妙自动微分思想。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.key(0)

梯度#

grad 开始#

您可以使用 grad 对函数进行微分

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816

grad 接受一个函数并返回一个函数。如果您有一个 Python 函数 f 用于评估数学函数 \(f\),那么 grad(f) 是一个用于评估数学函数 \(\nabla f\) 的 Python 函数。这意味着 grad(f)(x) 表示 \(\nabla f(x)\) 的值。

由于 grad 作用于函数,您可以将其应用于其自身的输出,以任意次地进行微分

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405

让我们看看在线性逻辑回归模型中如何使用 grad 函数计算梯度。首先,是设置部分

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

使用 grad 函数及其 argnums 参数,对函数的指定位置参数求微分。

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.433146  -0.7354605 -1.2598922]
W_grad [-0.433146  -0.7354605 -1.2598922]
b_grad -0.6900178
W_grad [-0.433146  -0.7354605 -1.2598922]
b_grad -0.6900178

grad API 与 Spivak 经典著作《流形上的微积分》(Calculus on Manifolds) (1965) 中卓越的记法直接对应,该记法也用于 Sussman 和 Wisdom 的《经典力学的结构与解释》(Structure and Interpretation of Classical Mechanics) (2015) 以及他们的《泛函微分几何》(Functional Differential Geometry) (2013)。这两本书均可公开获取。特别是《泛函微分几何》的“序言”部分,为这种记法进行了辩护。

实质上,当使用 argnums 参数时,如果 f 是用于评估数学函数 \(f\) 的 Python 函数,则 Python 表达式 grad(f, i) 的评估结果是用于评估 \(\partial_i f\) 的 Python 函数。

对嵌套列表、元组和字典求微分#

对标准 Python 容器求微分是可行的,因此您可以随意使用元组、列表和字典(以及任意嵌套)。

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32), 'b': Array(-0.6900178, dtype=float32)}

您可以注册您自己的容器类型,使其不仅能与 grad 配合使用,还能与所有 JAX 转换(jitvmap 等)配合使用。

使用 value_and_grad 评估函数及其梯度#

另一个方便的函数是 value_and_grad,它能高效地计算函数的值及其梯度值

from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 2.9729187
loss value 2.9729187

与数值差分进行校验#

导数的一个优点是它们可以用有限差分法轻松验证

# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.6890297
b_grad_autodiff -0.6900178
W_dirderiv_numerical 1.3017654
W_dirderiv_autodiff 1.3006744

JAX 提供了一个简单的便捷函数,其功能基本相同,但可以检查您喜欢的任意阶微分。

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives

使用 gradgrad 计算 Hessian-向量积#

使用高阶 grad 我们可以构建一个 Hessian-向量积函数。(稍后我们将编写一个更高效的实现,它混合了前向模式和反向模式,但这个实现将使用纯反向模式。)

Hessian-向量积函数在截断牛顿共轭梯度算法中可能很有用,用于最小化平滑凸函数,或用于研究神经网络训练目标函数的曲率(例如1234)。

对于一个具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\)(因此 Hessian 矩阵是对称的),在点 \(x \in \mathbb{R}^n\) 处的 Hessian 矩阵记作 \(\partial^2 f(x)\)。Hessian-向量积函数则能够评估

\(\qquad v \mapsto \partial^2 f(x) \cdot v\)

对于任何 \(v \in \mathbb{R}^n\)

这里的诀窍是不实例化完整的 Hessian 矩阵:如果 \(n\) 很大,在神经网络的背景下可能达到数百万或数十亿,那么存储它可能是不可能的。

幸运的是,grad 已经为我们提供了一种编写高效 Hessian-向量积函数的方法。我们只需要使用恒等式

\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),

其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(f\)\(x\) 处的梯度与向量 \(v\) 进行点积。请注意,我们只对向量值参数的标量值函数进行微分,这正是我们知道 grad 效率高的地方。

在 JAX 代码中,我们可以这样写

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

这个例子表明您可以自由使用词法闭包,JAX 永远不会受到干扰或混淆。

我们将在几个单元格后验证此实现,一旦我们了解如何计算稠密 Hessian 矩阵。我们还将编写一个更好版本,它同时使用前向模式和反向模式。

使用 jacfwdjacrev 计算 Jacobian 和 Hessian#

您可以使用 jacfwdjacrev 函数计算完整的 Jacobian 矩阵。

from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05069415  0.1091874   0.07506633]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447992]
 [ 0.00574409 -0.0193281   0.01078958]]
jacrev result, with shape (4, 3)
[[ 0.05069415  0.10918741  0.07506634]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447995]
 [ 0.00574409 -0.0193281   0.01078958]]

这两个函数计算相同的值(在机器精度范围内),但它们的实现方式不同:jacfwd 使用前向模式自动微分,对于“高瘦”Jacobian 矩阵(输出多于输入)更高效;而 jacrev 使用反向模式,对于“宽胖”Jacobian 矩阵(输入多于输出)更高效。对于接近方阵的矩阵,jacfwd 可能比 jacrev 略有优势。

您也可以将 jacfwdjacrev 与容器类型一起使用

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)
Jacobian from W to logits is
[[ 0.05069415  0.10918741  0.07506634]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447995]
 [ 0.00574409 -0.0193281   0.01078958]]
Jacobian from b to logits is
[0.09748876 0.16102302 0.24190766 0.00776229]

有关前向模式和反向模式的更多详细信息,以及如何尽可能高效地实现 jacfwdjacrev,请继续阅读!

使用这两个函数的组合,我们可以计算稠密的 Hessian 矩阵

def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02058932  0.04434624  0.03048803]
  [ 0.04434623  0.09551499  0.06566654]
  [ 0.03048803  0.06566655  0.04514575]]

 [[-0.0743913   0.09129842 -0.01268033]
  [ 0.09129842 -0.11204806  0.01556223]
  [-0.01268034  0.01556223 -0.00216142]]

 [[ 0.01176856  0.00135791 -0.02942139]
  [ 0.00135791  0.00015668 -0.00339478]
  [-0.0294214  -0.00339478  0.07355348]]

 [[-0.00418412  0.014079   -0.00785936]
  [ 0.014079   -0.04737393  0.02644569]
  [-0.00785936  0.02644569 -0.01476286]]]

这种形状是合理的:如果我们从一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,我们期望得到如下形状:

  • \(f(x) \in \mathbb{R}^m\)\(f\)\(x\) 处的值,

  • \(\partial f(x) \in \mathbb{R}^{m \times n}\),在 \(x\) 处的 Jacobian 矩阵,

  • \(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),在 \(x\) 处的 Hessian,

等等。

为了实现 hessian,我们可以使用 jacfwd(jacrev(f))jacrev(jacfwd(f)) 或两者的任何其他组合。但前向-后向的组合通常效率最高。这是因为在内部 Jacobian 计算中,我们通常是对一个具有宽 Jacobian 的函数(可能像损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\))进行微分,而在外部 Jacobian 计算中,我们是对一个具有方阵 Jacobian 的函数(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\))进行微分,而前向模式在这种情况下表现最佳。

如何实现:两个基础自动微分函数#

Jacobian-向量积(JVPs,即前向模式自动微分)#

JAX 包含了前向模式和反向模式自动微分的高效通用实现。我们熟悉的 grad 函数是基于反向模式构建的,但要解释这两种模式的区别以及何时何地每种模式有用,我们需要一些数学背景知识。

数学中的 JVPs#

在数学上,给定一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)\(f\) 在输入点 \(x \in \mathbb{R}^n\) 处评估的 Jacobian 矩阵,记作 \(\partial f(x)\),通常被认为是 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的一个矩阵

\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).

但我们也可以将 \(\partial f(x)\) 视为一个线性映射,它将 \(f\) 定义域在 \(x\) 点处的切空间(也就是 \(\mathbb{R}^n\) 的另一个副本)映射到 \(f\) 余定义域在 \(f(x)\) 点处的切空间(\(\mathbb{R}^m\) 的一个副本)

\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).

这个映射称为 \(f\)\(x\) 处的推前映射。Jacobian 矩阵只是这个线性映射在标准基下的矩阵。

如果我们不指定一个具体的输入点 \(x\),那么我们可以把函数 \(\partial f\) 看作是先接受一个输入点,然后返回该输入点处的 Jacobian 线性映射

\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).

特别是,我们可以将参数分解,使得给定输入点 \(x \in \mathbb{R}^n\) 和一个切向量 \(v \in \mathbb{R}^n\),我们得到一个在 \(\mathbb{R}^m\) 中的输出切向量。我们将从 \((x, v)\) 对到输出切向量的映射称为 Jacobian-向量积,并将其写作

\(\qquad (x, v) \mapsto \partial f(x) v\)

JAX 代码中的 JVPs#

回到 Python 代码中,JAX 的 jvp 函数模拟了这种转换。给定一个评估 \(f\) 的 Python 函数,JAX 的 jvp 是一种获取评估 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的 Python 函数的方法。

from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))

Haskell 风格的类型签名,我们可以写成

jvp :: (a -> b) -> a -> T a -> (b, T b)

其中我们使用 T a 来表示 a 的切线空间的类型。换句话说,jvp 接受类型为 a -> b 的函数、类型为 a 的值和类型为 T a 的切向量值作为参数。它返回一个对,包含类型为 b 的值和类型为 T b 的输出切向量。

经过 jvp 转换的函数评估方式与原始函数非常相似,但它将类型为 a 的每个原始值与类型为 T a 的切线值配对并一起推算。对于原始函数将应用的每个原始数值操作,经过 jvp 转换的函数会执行该原始操作的“JVP 规则”,该规则既会在原始值上评估该原始操作,又会在这些原始值上应用该原始操作的 JVP。

这种评估策略对计算复杂性有一些直接影响:由于我们边计算边评估 JVP,因此无需存储任何内容以备后用,因此内存成本与计算深度无关。此外,jvp 转换函数的 FLOP 成本约为仅评估函数成本的 3 倍(评估原始函数的工作量为 1 个单位,例如 sin(x);线性化的工作量为 1 个单位,例如 cos(x);以及将线性化函数应用于向量的工作量为 1 个单位,例如 cos_x * v)。换句话说,对于一个固定的原始点 \(x\),我们可以用与评估 \(f\) 相同的边际成本来评估 \(v \mapsto \partial f(x) \cdot v\)

这种内存复杂度听起来很有说服力!那为什么在机器学习中我们不经常看到前向模式呢?

要回答这个问题,首先思考如何使用 JVP 构建完整的 Jacobian 矩阵。如果我们对一个独热切向量应用 JVP,它将揭示 Jacobian 矩阵的一列,对应于我们输入的非零项。因此,我们可以一次构建 Jacobian 的一列,并且获取每列的成本与一次函数评估的成本大致相同。对于具有“高瘦”Jacobian 的函数,这将是高效的,但对于“宽胖”Jacobian 的函数,则效率低下。

如果您在机器学习中进行基于梯度的优化,您可能希望最小化一个将 \(\mathbb{R}^n\) 中的参数映射到标量损失值 \(\mathbb{R}\) 的损失函数。这意味着这个函数的 Jacobian 是一个非常宽的矩阵:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们通常将其与梯度向量 \(\nabla f(x) \in \mathbb{R}^n\) 等同。一次一列地构建这个矩阵,每次调用都花费与评估原始函数相似的 FLOPs,这看起来确实效率低下!特别是对于训练神经网络,其中 \(f\) 是训练损失函数,而 \(n\) 可以达到数百万甚至数十亿,这种方法根本无法扩展。

为了更好地处理此类函数,我们只需要使用反向模式。

向量-Jacobian 积(VJPs,即反向模式自动微分)#

前向模式返回一个用于评估 Jacobian-向量积的函数,我们可以用它一次构建 Jacobian 矩阵的一列;而反向模式则返回一个用于评估向量-Jacobian 积(等同于 Jacobian 转置-向量积)的函数,我们可以用它一次构建 Jacobian 矩阵的一行。

数学中的 VJPs#

让我们再次考虑函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。从我们对 JVP 的记法开始,VJP 的记法非常简单

\(\qquad (x, v) \mapsto v \partial f(x)\),

其中 \(v\)\(f\)\(x\) 处的余切空间的一个元素(同构于 \(\mathbb{R}^m\) 的另一个副本)。严格来说,我们应该将 \(v\) 视为一个线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),当我们写 \(v \partial f(x)\) 时,我们指的是函数组合 \(v \circ \partial f(x)\),其中类型匹配是因为 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常见情况下,我们可以将 \(v\) 视为 \(\mathbb{R}^m\) 中的向量,并几乎可以互换地使用两者,就像我们有时会不加评论地在“列向量”和“行向量”之间切换一样。

通过这种识别,我们可以将 VJP 的线性部分视为 JVP 线性部分的转置(或伴随共轭)

\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).

对于给定点 \(x\),我们可以将签名写为

\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).

在余切空间上对应的映射通常称为 \(f\)\(x\) 处的拉回。对我们而言,关键是它从看起来像 \(f\) 输出的东西转变为看起来像 \(f\) 输入的东西,就像我们可能从转置线性函数中预期的那样。

JAX 代码中的 VJPs#

从数学回到 Python,JAX 函数 vjp 可以接受一个用于评估 \(f\) 的 Python 函数,并返回一个用于评估 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函数。

from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)

Haskell 风格的类型签名,我们可以写成

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

其中我们使用 CT a 来表示 a 的余切空间的类型。换句话说,vjp 接受一个类型为 a -> b 的函数和一个类型为 a 的点作为参数,并返回一个对,包含一个类型为 b 的值和一个类型为 CT b -> CT a 的线性映射。

这很棒,因为它允许我们一次构建 Jacobian 矩阵的一行,并且评估 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOP 成本仅为评估 \(f\) 成本的三倍左右。特别是,如果我们想要一个函数 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我们只需一次调用即可完成。这就是 grad 对基于梯度的优化效率高的原因,即使对于神经网络训练损失函数等具有数百万或数十亿参数的目标函数也是如此。

然而,这也有代价:尽管 FLOPs 很友好,但内存会随着计算深度而扩展。此外,实现传统上比前向模式更复杂,尽管 JAX 有一些技巧(那是未来笔记本的故事!)。

有关反向模式如何工作的更多信息,请参阅2017 年深度学习暑期学校的这篇教程视频

使用 VJPs 计算向量值梯度#

如果您对计算向量值梯度感兴趣(例如 tf.gradients

from jax import vjp

def vgrad(f, x):
  y, vjp_fn = vjp(f, x)
  return vjp_fn(jnp.ones(y.shape))[0]

print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
 [6. 6.]]

同时使用前向模式和反向模式计算 Hessian-向量积#

在上一节中,我们仅使用反向模式实现了一个 Hessian-向量积函数(假设连续二阶导数)

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

这很高效,但我们可以做得更好,通过同时使用前向模式和反向模式来节省一些内存。

数学上,给定一个要微分的函数 \(f : \mathbb{R}^n \to \mathbb{R}\),一个对函数进行线性化的点 \(x \in \mathbb{R}^n\),以及一个向量 \(v \in \mathbb{R}^n\),我们想要的 Hessian-向量积函数是

\((x, v) \mapsto \partial^2 f(x) v\)

考虑辅助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。我们只需要它的 JVP,因为那会给我们

\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).

我们可以将其几乎直接翻译成代码

from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

更好的是,由于我们不必直接调用 jnp.dot,这个 hvp 函数适用于任何形状的数组和任意容器类型(例如存储为嵌套列表/字典/元组的向量),甚至不依赖于 jax.numpy

下面是一个如何使用它的示例

def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True

另一种你可以考虑的写法是使用反向-前向模式

# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals)

不过,这并不完全好,因为前向模式的开销比反向模式小,而且由于这里外部的微分算子必须对一个比内部算子更大的计算进行微分,所以保持前向模式在外部效果最好。

# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)


print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
4.46 ms ± 90.1 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
11.9 ms ± 7.28 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
15 ms ± 6.26 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
55.6 ms ± 2.19 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

组合 VJPs、JVPs 和 vmap#

Jacobian-矩阵和矩阵-Jacobian 积#

既然我们已经了解了快速 Jacobian-矩阵和矩阵-Jacobian 积,那么猜测如何编写 jacfwdjacrev 就不难了。我们只需使用相同的技术来一次性推前或拉回整个标准基(同构于单位矩阵)。

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return jnp.vstack([vjp_fun(mi) for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
62.4 ms ± 315 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
4.53 ms ± 51.2 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_2026/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
  return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
163 ms ± 290 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
2.28 ms ± 47.6 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

jacfwdjacrev 的实现#

既然我们已经看到了快速的 Jacobian-矩阵乘积和矩阵-Jacobian 乘积,那么要猜测如何编写 jacfwdjacrev 就不难了。我们只需使用相同的技术,一次性推前或拉回整个标准基(同构于单位矩阵)。

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

有趣的是,Autograd 无法做到这一点。我们在 Autograd 中反向模式 jacobian 的实现必须通过一个外层循环 map 一次拉回一个向量。一次将一个向量推过计算比使用 vmap 将它们全部批量处理效率低得多。

Autograd 无法做到的另一件事是 jit。有趣的是,无论您在要微分的函数中使用多少 Python 动态特性,我们总是可以在计算的线性部分使用 jit。例如

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)

复数与微分#

JAX 在复数和微分方面表现出色。为了支持全纯和非全纯微分,最好从 JVP 和 VJP 的角度来思考。

考虑一个复数到复数的函数 \(f: \mathbb{C} \to \mathbb{C}\),并将其与对应的函数 \(g: \mathbb{R}^2 \to \mathbb{R}^2\) 等同,

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y))

也就是说,我们已经将 \(f(z) = u(x, y) + v(x, y) i\) 分解,其中 \(z = x + y i\),并将 \(\mathbb{C}\)\(\mathbb{R}^2\) 等同以得到 \(g\)

由于 \(g\) 只涉及实数输入和输出,我们已经知道如何为其编写 Jacobian-向量积,例如给定一个切向量 \((c, d) \in \mathbb{R}^2\),即

\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).

为了得到应用于切向量 \(c + di \in \mathbb{C}\) 的原始函数 \(f\) 的 JVP,我们只需使用相同的定义并将结果识别为另一个复数,

\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).

这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数 JVP 的定义!请注意,\(f\) 是否全纯并不重要:JVP 是明确的。

这是一个检查

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # tangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j

  # check jvp
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j+
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True

VJP 呢?我们做的事情很相似:对于余切向量 \(c + di \in \mathbb{C}\),我们将 \(f\) 的 VJP 定义为

\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).

负号是做什么用的?它们只是为了处理复共轭,以及我们正在处理余向量的事实。

以下是 VJP 规则的检查

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # cotangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # for dtype control

  # check vjp
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)

那么像 gradjacfwdjacrev 这样的便捷封装器呢?

对于 \(\mathbb{R} \to \mathbb{R}\) 函数,回想一下我们定义 grad(f)(x)vjp(f, x)[1](1.0),这之所以有效是因为将 VJP 应用于 1.0 值会揭示梯度(即 Jacobian,或导数)。我们也可以对 \(\mathbb{C} \to \mathbb{R}\) 函数做同样的事情:我们仍然可以使用 1.0 作为余切向量,并且我们只会得到一个总结完整 Jacobian 的复数结果

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)

对于一般的 \(\mathbb{C} \to \mathbb{C}\) 函数,Jacobian 具有 4 个实值自由度(如上述 2x2 Jacobian 矩阵),因此我们无法将其全部表示为一个复数。但对于全纯函数,我们可以!全纯函数恰好是一个 \(\mathbb{C} \to \mathbb{C}\) 函数,其特殊性质是其导数可以表示为单个复数。(柯西-黎曼方程确保上述 2x2 Jacobian 具有复平面中缩放和旋转矩阵的特殊形式,即单个复数乘法的作用。)我们可以通过一次调用 vjp 和一个 1.0 的余向量来揭示这个复数。

由于这仅适用于全纯函数,因此要使用此技巧,我们需要向 JAX 承诺我们的函数是全纯的;否则,当 grad 用于复数输出函数时,JAX 将引发错误

def f(z):
  return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)

所有 holomorphic=True 承诺只是在输出为复数值时禁用错误。当函数不是全纯时,我们仍然可以编写 holomorphic=True,但我们得到的结果不会表示完整的 Jacobian。相反,它将是函数在丢弃输出虚部后的 Jacobian

def f(z):
  return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)

这里关于 grad 的工作方式有一些有用的结果

  1. 我们可以在全纯 \(\mathbb{C} \to \mathbb{C}\) 函数上使用 grad

  2. 我们可以使用 grad 来优化 \(f : \mathbb{C} \to \mathbb{R}\) 函数,例如复数参数 x 的实值损失函数,通过沿着 grad(f)(x) 的共轭方向进行步进。

  3. 如果我们有一个 \(\mathbb{R} \to \mathbb{R}\) 函数,它只是碰巧在内部使用了一些复数值操作(其中一些必须是非全纯的,例如卷积中使用的 FFT),那么 grad 仍然有效,并且我们得到的结果与仅使用实数值的实现所得到的结果相同。

无论如何,JVPs 和 VJPs 总是明确的。如果我们要计算非全纯 \(\mathbb{C} \to \mathbb{C}\) 函数的完整 Jacobian 矩阵,我们可以使用 JVPs 或 VJPs 来完成!

您应该期待复数在 JAX 的任何地方都能正常工作。以下是对复数矩阵进行 Cholesky 分解的微分

A = jnp.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])

def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A)
Array([[-0.75341904 +0.j       , -3.0509028 -10.940545j ,
         5.9896846  +3.5423026j],
       [-3.0509028 +10.940545j , -8.904491   +0.j       ,
        -5.1351523  -6.559373j ],
       [ 5.9896846  -3.5423026j, -5.1351523  +6.559373j ,
         0.01320427 +0.j       ]], dtype=complex64)

更高级的自动微分#

在本笔记本中,我们从一些简单到逐步复杂的 JAX 自动微分应用进行了探讨。我们希望您现在觉得在 JAX 中进行微分既简单又强大。

自动微分领域还有许多其他技巧和功能。我们未涵盖但希望在“高级自动微分指南”中涉及的主题包括:

  • 高斯-牛顿向量积,一次线性化

  • 自定义 VJP 和 JVP

  • 固定点处的有效导数

  • 使用随机 Hessian-向量积估计 Hessian 的迹。

  • 仅使用反向模式自动微分实现前向模式自动微分。

  • 对自定义数据类型进行微分。

  • 检查点(用于高效反向模式的二项式检查点,而非模型快照)。

  • 使用 Jacobian 预累积优化 VJP。