自动微分秘籍#

Open in Colab Open in Kaggle

JAX 有一个非常通用的自动微分系统。在本笔记本中,我们将浏览一系列整洁的自动微分想法,您可以为自己的工作挑选,从基础知识开始。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.key(0)

梯度#

grad 开始#

您可以使用 grad 对函数求微分

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816

grad 接受一个函数并返回一个函数。如果你有一个 Python 函数 f,它计算数学函数 \(f\),那么 grad(f) 是一个 Python 函数,它计算数学函数 \(\nabla f\)。这意味着 grad(f)(x) 表示值 \(\nabla f(x)\)

由于 grad 操作的是函数,你可以将其应用于自身的输出,以进行任意次数的微分。

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405

让我们看看如何在线性逻辑回归模型中使用 grad 计算梯度。首先,进行设置。

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

使用 grad 函数的 argnums 参数来对函数相对于位置参数进行微分。

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.43314594 -0.7354604  -1.2598921 ]
W_grad [-0.43314594 -0.7354604  -1.2598921 ]
b_grad -0.6900177
W_grad [-0.43314594 -0.7354604  -1.2598921 ]
b_grad -0.6900177

这个 grad API 直接对应于 Spivak 的经典著作 Calculus on Manifolds (1965) 中的优秀符号表示,也用于 Sussman 和 Wisdom 的 Structure and Interpretation of Classical Mechanics (2015) 和他们的 Functional Differential Geometry (2013)。这两本书都是开放获取的。特别请参阅 Functional Differential Geometry 的“序言”部分,以了解对该符号表示的辩护。

本质上,当使用 argnums 参数时,如果 f 是一个用于计算数学函数 \(f\) 的 Python 函数,那么 Python 表达式 grad(f, i) 的计算结果是一个用于计算 \(\partial_i f\) 的 Python 函数。

对嵌套列表、元组和字典进行微分#

对标准 Python 容器进行微分就可以直接使用,所以你可以随意使用元组、列表和字典(以及任意嵌套)。

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.43314594, -0.7354604 , -1.2598921 ], dtype=float32), 'b': Array(-0.6900177, dtype=float32)}

你可以注册你自己的容器类型,使其不仅可以与 grad 一起使用,还可以与所有的 JAX 转换(jitvmap 等)一起使用。

使用 value_and_grad 计算函数及其梯度#

另一个方便的函数是 value_and_grad,用于有效地计算函数的值及其梯度的值。

from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 2.9729187
loss value 2.9729187

与数值差分进行检查#

关于导数的一个好处是,可以使用有限差分直接进行检查。

# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.6890297
b_grad_autodiff -0.6900177
W_dirderiv_numerical 1.3017654
W_dirderiv_autodiff 1.3006743

JAX 提供了一个简单的便利函数,它基本上做着相同的事情,但可以检查任何你喜欢的微分阶数。

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives

使用 grad-of-grad 计算 Hessian-向量积#

我们可以使用高阶 grad 做的一件事是构建一个 Hessian-向量积函数。(稍后我们将编写一个更高效的实现,它混合使用前向模式和反向模式,但这个将使用纯反向模式。)

Hessian-向量积函数在截断牛顿共轭梯度算法中用于最小化平滑凸函数,或用于研究神经网络训练目标(例如,1, 2, 3, 4)的曲率时非常有用。

对于具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\) (因此 Hessian 矩阵是对称的),在点 \(x \in \mathbb{R}^n\) 处的 Hessian 写为 \(\partial^2 f(x)\)。然后,Hessian-向量积函数能够计算

\(\qquad v \mapsto \partial^2 f(x) \cdot v\)

对于任何 \(v \in \mathbb{R}^n\)

关键是不实例化完整的 Hessian 矩阵:如果 \(n\) 很大,也许在神经网络的背景下是数百万或数十亿,那么这可能无法存储。

幸运的是,grad 已经为我们提供了一种编写高效 Hessian-向量积函数的方法。我们只需要使用恒等式

\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),

其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(f\)\(x\) 处的梯度与向量 \(v\) 点积。请注意,我们只对向量值参数的标量值函数进行微分,这正是我们知道 grad 有效的地方。

在 JAX 代码中,我们可以这样写

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

这个例子表明你可以自由地使用词法闭包,JAX 永远不会被打扰或困惑。

一旦我们了解如何计算密集的 Hessian 矩阵,我们将在接下来的几个单元格中检查此实现。我们还将编写一个更好的版本,它同时使用前向模式和反向模式。

使用 jacfwdjacrev 计算雅可比矩阵和 Hessian 矩阵#

你可以使用 jacfwdjacrev 函数计算完整的雅可比矩阵。

from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05069415  0.1091874   0.07506633]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447992]
 [ 0.00574409 -0.0193281   0.01078958]]
jacrev result, with shape (4, 3)
[[ 0.05069415  0.10918739  0.07506634]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447995]
 [ 0.00574409 -0.0193281   0.01078958]]

这两个函数计算相同的值(直到机器数值),但在实现上有所不同:jacfwd 使用前向模式自动微分,对于“高”雅可比矩阵(输出多于输入)更有效,而 jacrev 使用反向模式,对于“宽”雅可比矩阵(输入多于输出)更有效。对于接近正方形的矩阵,jacfwd 可能比 jacrev 具有优势。

你也可以将 jacfwdjacrev 与容器类型一起使用。

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)
Jacobian from W to logits is
[[ 0.05069415  0.10918739  0.07506634]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447995]
 [ 0.00574409 -0.0193281   0.01078958]]
Jacobian from b to logits is
[0.09748875 0.16102302 0.24190766 0.00776229]

有关前向模式和反向模式的更多详细信息,以及如何尽可能高效地实现 jacfwdjacrev,请继续阅读!

使用这两个函数的组合为我们提供了一种计算密集 Hessian 矩阵的方法。

def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02058932  0.04434624  0.03048803]
  [ 0.04434623  0.09551499  0.06566654]
  [ 0.03048803  0.06566655  0.04514575]]

 [[-0.0743913   0.09129842 -0.01268033]
  [ 0.09129842 -0.11204806  0.01556223]
  [-0.01268034  0.01556223 -0.00216142]]

 [[ 0.01176856  0.00135791 -0.02942139]
  [ 0.00135791  0.00015668 -0.00339478]
  [-0.0294214  -0.00339478  0.07355348]]

 [[-0.00418412  0.014079   -0.00785936]
  [ 0.014079   -0.04737393  0.02644569]
  [-0.00785936  0.02644569 -0.01476286]]]

这种形状是有意义的:如果我们从函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,我们期望得到以下形状

  • \(f(x) \in \mathbb{R}^m\)\(f\)\(x\) 处的值,

  • \(\partial f(x) \in \mathbb{R}^{m \times n}\),在 \(x\) 处的雅可比矩阵,

  • \(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),在 \(x\) 处的 Hessian 矩阵,

等等。

要实现 hessian,我们可以使用 jacfwd(jacrev(f))jacrev(jacfwd(f)) 或这两个函数的任何其他组合。但是,前向模式优先于反向模式通常是最有效的。这是因为在内部雅可比计算中,我们经常对具有宽雅可比矩阵的函数进行微分(可能像损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\)),而在外部雅可比计算中,我们对具有正方形雅可比矩阵的函数进行微分(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),这就是前向模式胜出的地方。

它是如何制作的:两个基础自动微分函数#

雅可比-向量积(JVP,又称前向模式自动微分)#

JAX 包含了前向和反向模式自动微分的高效和通用实现。我们熟悉的 grad 函数是基于反向模式构建的,但是为了解释两种模式的区别以及每种模式何时有用,我们需要一些数学背景知识。

数学中的 JVP#

在数学上,给定一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)\(f\) 在输入点 \(x \in \mathbb{R}^n\) 处计算的雅可比矩阵,表示为 \(\partial f(x)\),通常被认为是 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的矩阵。

\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).

但是,我们也可以将 \(\partial f(x)\) 视为线性映射,它将 \(f\) 定义域在点 \(x\) 处的切空间(这只是 \(\mathbb{R}^n\) 的另一个副本)映射到 \(f\) 的值域在点 \(f(x)\) 处的切空间(\(\mathbb{R}^m\) 的副本)

\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).

此映射称为 \(f\)\(x\) 处的前推映射。雅可比矩阵只是此线性映射在标准基下的矩阵。

如果我们不提交到一个特定的输入点 \(x\),那么我们可以将函数 \(\partial f\) 视为首先接受一个输入点,并返回该输入点处的雅可比线性映射

\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).

特别是,我们可以解柯里化,使得给定输入点 \(x \in \mathbb{R}^n\) 和一个切向量 \(v \in \mathbb{R}^n\),我们可以得到 \(\mathbb{R}^m\) 中的一个输出切向量。我们将从 \((x, v)\) 对到输出切向量的映射称为雅可比-向量积,并将其写作

\(\qquad (x, v) \mapsto \partial f(x) v\)

JAX 代码中的 JVP#

回到 Python 代码中,JAX 的 jvp 函数模拟了这种转换。给定一个计算 \(f\) 的 Python 函数,JAX 的 jvp 是一种获取计算 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的 Python 函数的方法。

from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))

类似 Haskell 的类型签名来说,我们可以写成

jvp :: (a -> b) -> a -> T a -> (b, T b)

其中我们使用 T a 来表示 a 的切空间的类型。换句话说,jvp 接受一个类型为 a -> b 的函数,一个类型为 a 的值,以及一个类型为 T a 的切向量值作为参数。它返回一个由类型为 b 的值和类型为 T b 的输出切向量组成的对。

经过 jvp 转换的函数与原始函数非常相似,但它会伴随着类型为 a 的每个原始值,推送类型为 T a 的切线值。对于原始函数将应用的每个基本数值运算,jvp 转换后的函数会执行该基本运算的“JVP 规则”,该规则既会在原始值上评估基本运算,又会在这些原始值上应用基本运算的 JVP。

这种评估策略对计算复杂度有一些直接的影响:由于我们在计算过程中评估 JVP,我们不需要为以后存储任何内容,因此内存成本与计算深度无关。此外,jvp 转换后函数的 FLOP 成本大约是仅评估函数成本的 3 倍(评估原始函数的工作量为一个单位,例如 sin(x);线性化的工作量为一个单位,例如 cos(x);以及将线性化函数应用于向量的工作量为一个单位,例如 cos_x * v)。换句话说,对于一个固定的原始点 \(x\),我们可以以与评估 \(f\) 大致相同的边际成本来评估 \(v \mapsto \partial f(x) \cdot v\)

这种内存复杂度听起来很有吸引力!那么为什么我们在机器学习中很少看到前向模式呢?

要回答这个问题,首先考虑一下如何使用 JVP 构建完整的雅可比矩阵。如果我们对一个 one-hot 切向量应用 JVP,它会显示雅可比矩阵的一列,对应于我们输入的非零条目。因此,我们可以一次构建一列完整的雅可比矩阵,并且获取每一列的成本与一次函数评估大致相同。对于具有“高”雅可比矩阵的函数来说,这将是高效的,但对于“宽”雅可比矩阵来说,则效率低下。

如果您在机器学习中进行基于梯度的优化,您可能希望最小化从 \(\mathbb{R}^n\) 中的参数到 \(\mathbb{R}\) 中的标量损失值的损失函数。这意味着此函数的雅可比矩阵是一个非常宽的矩阵:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们通常将其等同于梯度向量 \(\nabla f(x) \in \mathbb{R}^n\)。一次构建一列该矩阵,并且每次调用都使用与评估原始函数相似的 FLOP,这看起来确实效率低下!特别是,对于训练神经网络,其中 \(f\) 是训练损失函数,并且 \(n\) 可能达到数百万甚至数十亿,这种方法将无法扩展。

为了更好地处理像这样的函数,我们只需要使用反向模式。

向量-雅可比积(VJP,又名反向模式自动微分)#

前向模式为我们提供了一个用于评估雅可比-向量积的函数,然后我们可以使用该函数一次构建一列雅可比矩阵,而反向模式是一种获取用于评估向量-雅可比积(等效于雅可比转置-向量积)的函数的方法,我们可以使用该函数一次构建一行雅可比矩阵。

数学中的 VJP#

让我们再次考虑一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。从我们 JVP 的符号开始,VJP 的符号非常简单

\(\qquad (x, v) \mapsto v \partial f(x)\),

其中 \(v\)\(f\)\(x\) 处的余切空间的一个元素(与另一个 \(\mathbb{R}^m\) 的副本同构)。当严格来说时,我们应该将 \(v\) 视为一个线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),当我们写 \(v \partial f(x)\) 时,我们指的是函数组合 \(v \circ \partial f(x)\),因为类型匹配,其中 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常见情况下,我们可以将 \(v\)\(\mathbb{R}^m\) 中的向量等同起来,并且几乎可以互换地使用这两个概念,就像我们有时会在“列向量”和“行向量”之间切换而无需过多解释一样。

通过这种等同,我们可以将 VJP 的线性部分视为 JVP 的线性部分的转置(或伴随共轭)

\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).

对于给定点 \(x\),我们可以将签名写为

\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).

余切空间上的相应映射通常被称为 \(f\)\(x\) 处的拉回。我们目的的关键在于,它从看起来像 \(f\) 的输出的东西转变为看起来像 \(f\) 的输入的东西,正如我们可能从转置线性函数中期望的那样。

JAX 代码中的 VJP#

从数学回到 Python,JAX 函数 vjp 可以接受一个用于评估 \(f\) 的 Python 函数,并返回一个用于评估 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函数。

from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)

类似 Haskell 的类型签名来说,我们可以写成

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

其中我们使用 CT a 来表示 a 的余切空间的类型。换句话说,vjp 接受一个类型为 a -> b 的函数和一个类型为 a 的点作为参数,并返回一个由类型为 b 的值和类型为 CT b -> CT a 的线性映射组成的对。

这非常棒,因为它允许我们一次构建一行雅可比矩阵,并且评估 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOP 成本仅约为评估 \(f\) 成本的三倍。特别是,如果我们想要函数 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我们可以在一次调用中完成。这就是为什么 grad 对于基于梯度的优化(即使对于数百万或数十亿参数的神经网络训练损失函数等目标)来说都是高效的。

但是,这也有一定的代价:虽然 FLOP 很友好,但内存会随着计算深度而扩展。此外,传统的实现比前向模式更复杂,尽管 JAX 有一些秘密武器(这将在未来的笔记本中讲述!)。

有关反向模式如何工作的更多信息,请参阅2017 年深度学习暑期学校的这个教程视频

使用 VJP 的向量值梯度#

如果您有兴趣获取向量值梯度(如 tf.gradients

from jax import vjp

def vgrad(f, x):
  y, vjp_fn = vjp(f, x)
  return vjp_fn(jnp.ones(y.shape))[0]

print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
 [6. 6.]]

使用前向模式和反向模式的 Hessian-向量积#

在上一节中,我们仅使用反向模式(假设连续二阶导数)实现了一个 Hessian-向量积函数

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

这很高效,但我们可以通过结合前向模式和反向模式来做得更好,并节省一些内存。

在数学上,给定一个要微分的函数 \(f : \mathbb{R}^n \to \mathbb{R}\),一个函数线性化的点 \(x \in \mathbb{R}^n\),以及一个向量 \(v \in \mathbb{R}^n\),我们想要的 Hessian-向量积函数是

\((x, v) \mapsto \partial^2 f(x) v\)

考虑辅助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),它被定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。我们只需要它的 JVP,因为它会给我们

\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).

我们可以将其几乎直接翻译成代码

from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

更好的是,由于我们不必直接调用 jnp.dot,这个 hvp 函数可以处理任何形状的数组和任意容器类型(例如存储为嵌套列表/字典/元组的向量),甚至不依赖于 jax.numpy

下面是如何使用它的示例

def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True

你可能考虑编写此代码的另一种方法是使用反向覆盖前向

# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals)

但这并不是很好,因为前向模式的开销比反向模式小,并且由于这里的外层微分算子必须微分比内层算子更大的计算,因此保持外层使用前向模式效果最佳。

# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)


print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
5.36 ms ± 75.3 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
13.6 ms ± 9.57 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
19.3 ms ± 13.4 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
55.7 ms ± 2.9 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

组合 VJP、JVP 和 vmap#

雅可比矩阵和矩阵-雅可比积#

现在我们有了 jvpvjp 转换,它们提供了用于一次前推或后拉单个向量的函数,我们可以使用 JAX 的 vmap 转换来一次前推和后拉整个基。 特别是,我们可以使用它来编写快速的矩阵-雅可比和雅可比-矩阵积。

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return jnp.vstack([vjp_fun(mi) for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
181 ms ± 775 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
5.91 ms ± 131 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_1261/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
  return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
242 ms ± 401 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
2.92 ms ± 121 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

jacfwdjacrev 的实现#

现在我们已经看到了快速的雅可比矩阵和矩阵-雅可比积,不难猜到如何编写 jacfwdjacrev。 我们只是使用相同的技术来一次前推或后拉整个标准基(与单位矩阵同构)。

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

有趣的是,Autograd 做不到这一点。 我们在 Autograd 中的反向模式 jacobian实现必须使用外循环 map 一次后拉一个向量。 一次将一个向量推过计算的效率远低于使用 vmap 将它们全部批处理在一起。

Autograd 无法做到的另一件事是 jit。 有趣的是,无论你在要微分的函数中使用多少 Python 动态性,我们始终可以在计算的线性部分使用 jit。 例如

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)

复数和微分#

JAX 非常擅长复数和微分。为了支持全纯和非全纯微分,从 JVP 和 VJP 的角度思考会有所帮助。

考虑一个复数到复数的函数 \(f: \mathbb{C} \to \mathbb{C}\) 并将其与对应的函数 \(g: \mathbb{R}^2 \to \mathbb{R}^2\) 关联起来,

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y))

也就是说,我们已经分解了 \(f(z) = u(x, y) + v(x, y) i\),其中 \(z = x + y i\),并将 \(\mathbb{C}\)\(\mathbb{R}^2\) 关联以获得 \(g\)

由于 \(g\) 仅涉及实数输入和输出,我们已经知道如何为其编写雅可比-向量积,例如,给定一个切向量 \((c, d) \in \mathbb{R}^2\),即

\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).

为了获得应用于切向量 \(c + di \in \mathbb{C}\) 的原始函数 \(f\) 的 JVP,我们只需使用相同的定义并将结果标识为另一个复数,

\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).

这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数的 JVP 的定义!请注意,\(f\) 是否是全纯的并不重要:JVP 是明确的。

这是一个检查

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # tangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j

  # check jvp
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j+
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True

VJP 怎么样?我们做一些非常相似的事情:对于余切向量 \(c + di \in \mathbb{C}\),我们将 \(f\) 的 VJP 定义为

\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).

负号是怎么回事?它们只是为了处理复共轭,以及我们正在使用余向量的事实。

这是对 VJP 规则的检查

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # cotangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # for dtype control

  # check vjp
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)

gradjacfwdjacrev 这样的便利包装器呢?

对于 \(\mathbb{R} \to \mathbb{R}\) 函数,回想一下我们将 grad(f)(x) 定义为 vjp(f, x)[1](1.0),之所以有效,是因为将 VJP 应用于 1.0 值会揭示梯度(即雅可比,或导数)。 我们可以对 \(\mathbb{C} \to \mathbb{R}\) 函数执行相同的操作:我们仍然可以使用 1.0 作为余切向量,并且我们只会得到一个复数结果,总结完整的雅可比

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)

对于一般的 \(\mathbb{C} \to \mathbb{C}\) 函数,雅可比矩阵具有 4 个实数值自由度(如上面的 2x2 雅可比矩阵所示),因此我们不能希望在复数内表示所有这些自由度。 但我们可以对全纯函数这样做! 全纯函数恰好是一个 \(\mathbb{C} \to \mathbb{C}\) 函数,其特殊性质是其导数可以用单个复数表示。 (柯西-黎曼方程确保上述 2x2 雅可比矩阵在复平面上具有尺度和旋转矩阵的特殊形式,即在乘法下单个数的复数的作用。) 我们可以使用对 vjp 的单个调用和一个 1.0 的余向量来揭示该复数。

因为这仅适用于全纯函数,所以要使用此技巧,我们需要向 JAX 保证我们的函数是全纯的; 否则,当 grad 用于复数输出函数时,JAX 将引发错误

def f(z):
  return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)

所有 holomorphic=True 承诺所做的只是在输出为复数时禁用错误。 当函数不是全纯时,我们仍然可以编写 holomorphic=True,但是我们得到的结果不会代表完整的雅可比矩阵。 相反,它将是函数的雅可比矩阵,其中我们只是丢弃了输出的虚部

def f(z):
  return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)

这里 grad 的工作方式有一些有用的结果

  1. 我们可以在全纯 \(\mathbb{C} \to \mathbb{C}\) 函数上使用 grad

  2. 我们可以通过在 grad(f)(x) 的共轭方向上采取步骤,在 \(f : \mathbb{C} \to \mathbb{R}\) 函数上使用 grad 来优化,例如复参数 x 的实值损失函数。

  3. 如果我们有一个 \(\mathbb{R} \to \mathbb{R}\) 函数,它恰好在内部使用了一些复数值运算(其中一些必须是非全纯的,例如卷积中使用的 FFT),那么 grad 仍然有效,我们得到的结果与仅使用实数值的实现所给出的结果相同。

在任何情况下,JVP 和 VJP 总是明确的。如果我们想要计算一个非全纯函数 \(\mathbb{C} \to \mathbb{C}\) 的完整雅可比矩阵,我们可以使用 JVP 或 VJP 来完成!

您应该期望复数在 JAX 中处处可用。这是对复矩阵的 Cholesky 分解进行微分的例子。

A = jnp.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])

def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A)
Array([[-0.7534186  +0.j       , -3.0509028 -10.940544j ,
         5.9896846  +3.5423026j],
       [-3.0509028 +10.940544j , -8.904491   +0.j       ,
        -5.1351523  -6.559373j ],
       [ 5.9896846  -3.5423026j, -5.1351523  +6.559373j ,
         0.01320427 +0.j       ]], dtype=complex64)

更高级的自动微分#

在本笔记本中,我们逐步介绍了 JAX 中自动微分的一些简单应用,然后逐渐介绍了更复杂的应用。我们希望您现在觉得在 JAX 中求导既简单又强大。

还有许多其他的自动微分技巧和功能。我们没有涵盖的主题,但希望在“高级自动微分食谱”中涵盖的主题包括:

  • 高斯-牛顿向量积,线性化一次

  • 自定义 VJP 和 JVP

  • 在定点处的高效导数

  • 使用随机 Hessian-向量积估计 Hessian 的迹。

  • 仅使用反向模式自动微分的前向模式自动微分。

  • 针对自定义数据类型求导。

  • 检查点(用于高效反向模式的二项式检查点,而不是模型快照)。

  • 使用雅可比预积累优化 VJP。