Autodiff Cookbook#
JAX 拥有相当通用的自动微分系统。在本笔记本中,我们将介绍大量简洁的自动微分思想,您可以根据自己的工作随意挑选,从基础知识开始。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0)
梯度#
从 grad
开始#
您可以使用 grad
对函数求微分
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
grad
接受一个函数并返回一个函数。如果您有一个 Python 函数 f
,用于评估数学函数 \(f\),那么 grad(f)
是一个 Python 函数,用于评估数学函数 \(\nabla f\)。这意味着 grad(f)(x)
表示值 \(\nabla f(x)\)。
由于 grad
对函数进行操作,您可以将其应用于自身的输出,以进行任意次数的微分
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
让我们看看在线性逻辑回归模型中使用 grad
计算梯度。首先,进行设置
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
使用 grad
函数及其 argnums
参数,以针对位置参数对函数进行微分。
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.43314594 -0.7354604 -1.2598921 ]
W_grad [-0.43314594 -0.7354604 -1.2598921 ]
b_grad -0.69001764
W_grad [-0.43314594 -0.7354604 -1.2598921 ]
b_grad -0.69001764
此 grad
API 与 Spivak 经典著作 Calculus on Manifolds (1965) 中的优秀符号直接对应,也用于 Sussman 和 Wisdom 的 Structure and Interpretation of Classical Mechanics (2015) 以及他们的 Functional Differential Geometry (2013)。这两本书都是开放获取的。特别是请参阅 Functional Differential Geometry 的“序言”部分,以了解对此符号的辩护。
本质上,当使用 argnums
参数时,如果 f
是用于评估数学函数 \(f\) 的 Python 函数,则 Python 表达式 grad(f, i)
评估为用于评估 \(\partial_i f\) 的 Python 函数。
对嵌套列表、元组和字典求微分#
对标准 Python 容器求微分可以直接使用,因此可以随意使用元组、列表和字典(以及任意嵌套)。
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.43314594, -0.7354604 , -1.2598921 ], dtype=float32), 'b': Array(-0.69001764, dtype=float32)}
您可以注册您自己的容器类型,使其不仅可以与 grad
一起使用,还可以与所有 JAX 转换 (jit
、vmap
等) 一起使用。
使用 value_and_grad
评估函数及其梯度#
另一个方便的函数是 value_and_grad
,用于高效地计算函数值及其梯度值
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 2.9729187
loss value 2.9729187
对照数值差异进行检查#
关于导数的一个优点是,它们可以使用有限差分法直接进行检查
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.6890297
b_grad_autodiff -0.69001764
W_dirderiv_numerical 1.3041496
W_dirderiv_autodiff 1.3006743
JAX 提供了一个简单的便捷函数,它执行基本相同的操作,但可以检查您喜欢的任何微分阶数
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
使用 grad
的 grad
计算 Hessian 向量积#
我们可以使用高阶 grad
做的一件事是构建 Hessian 向量积函数。(稍后我们将编写更高效的实现,它混合了正向模式和反向模式,但这个实现将使用纯反向模式。)
Hessian 向量积函数在 截断牛顿共轭梯度算法中用于最小化平滑凸函数,或用于研究神经网络训练目标的曲率(例如 1, 2, 3, 4)。
对于具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\)(因此 Hessian 矩阵是对称的),点 \(x \in \mathbb{R}^n\) 处的 Hessian 矩阵写为 \(\partial^2 f(x)\)。然后,Hessian 向量积函数能够评估
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
对于任何 \(v \in \mathbb{R}^n\)。
诀窍是不实例化完整的 Hessian 矩阵:如果 \(n\) 很大,在神经网络的上下文中可能达到数百万或数十亿,那么可能无法存储它。
幸运的是,grad
已经为我们提供了一种编写高效 Hessian 向量积函数的方法。我们只需要使用恒等式
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(x\) 处的 \(f\) 的梯度与向量 \(v\) 进行点积运算。请注意,我们只对向量值参数的标量值函数进行微分,这正是我们知道 grad
高效的地方。
在 JAX 代码中,我们可以这样写
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
此示例表明您可以自由使用词法闭包,而 JAX 永远不会感到困惑或迷惑。
我们将在稍后的单元格中检查此实现,一旦我们了解如何计算稠密 Hessian 矩阵。我们还将编写一个更好的版本,它同时使用正向模式和反向模式。
使用 jacfwd
和 jacrev
计算雅可比矩阵和 Hessian 矩阵#
您可以使用 jacfwd
和 jacrev
函数计算完整的雅可比矩阵
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05069415 0.1091874 0.07506633]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447992]
[ 0.00574409 -0.0193281 0.01078958]]
jacrev result, with shape (4, 3)
[[ 0.05069415 0.10918739 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
这两个函数计算相同的值(达到机器数值精度),但它们的实现有所不同:jacfwd
使用正向模式自动微分,对于“高”雅可比矩阵(输出多于输入)更有效,而 jacrev
使用反向模式,对于“宽”雅可比矩阵(输入多于输出)更有效。对于接近正方形的矩阵,jacfwd
可能比 jacrev
更有优势。
您还可以将 jacfwd
和 jacrev
与容器类型一起使用
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05069415 0.10918739 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
Jacobian from b to logits is
[0.09748875 0.16102302 0.24190766 0.00776229]
有关正向模式和反向模式的更多详细信息,以及如何尽可能高效地实现 jacfwd
和 jacrev
,请继续阅读!
使用这两个函数的组合为我们提供了一种计算稠密 Hessian 矩阵的方法
def hessian(f):
return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02058932 0.04434624 0.03048803]
[ 0.04434623 0.09551499 0.06566654]
[ 0.03048803 0.06566655 0.04514575]]
[[-0.0743913 0.09129842 -0.01268033]
[ 0.09129842 -0.11204806 0.01556223]
[-0.01268034 0.01556223 -0.00216142]]
[[ 0.01176856 0.00135791 -0.02942139]
[ 0.00135791 0.00015668 -0.00339478]
[-0.0294214 -0.00339478 0.07355348]]
[[-0.00418412 0.014079 -0.00785936]
[ 0.014079 -0.04737393 0.02644569]
[-0.00785936 0.02644569 -0.01476286]]]
这种形状是有道理的:如果我们从函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,我们期望得到以下形状
\(f(x) \in \mathbb{R}^m\),\(f\) 在 \(x\) 处的值,
\(\partial f(x) \in \mathbb{R}^{m \times n}\),\(x\) 处的雅可比矩阵,
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),\(x\) 处的 Hessian 矩阵,
等等。
为了实现 hessian
,我们可以使用 jacfwd(jacrev(f))
或 jacrev(jacfwd(f))
或这两个函数的任何其他组合。但正向模式 over 反向模式通常是最有效的。这是因为在内部雅可比矩阵计算中,我们通常对宽雅可比矩阵的函数(可能像损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\))进行微分,而在外部雅可比矩阵计算中,我们对具有正方形雅可比矩阵的函数进行微分(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),而这正是正向模式胜出的地方。
原理:两个基础的自动微分函数#
雅可比-向量积 (JVPs,又名正向模式自动微分)#
JAX 包括正向模式和反向模式自动微分的高效通用实现。熟悉的 grad
函数基于反向模式构建,但为了解释两种模式之间的区别以及每种模式何时有用,我们需要一些数学背景知识。
JVPs 的数学原理#
在数学上,给定函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\),在输入点 \(x \in \mathbb{R}^n\) 处评估的 \(f\) 的雅可比矩阵,表示为 \(\partial f(x)\),通常被认为是 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的矩阵
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
但我们也可以将 \(\partial f(x)\) 视为线性映射,它将 \(f\) 的域在点 \(x\) 处的切空间(它只是 \(\mathbb{R}^n\) 的另一个副本)映射到 \(f\) 的共域在点 \(f(x)\) 处的切空间(\(\mathbb{R}^m\) 的副本)
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
此映射称为 \(f\) 在 \(x\) 处的前推映射。雅可比矩阵只是此线性映射在标准基下的矩阵。
如果我们不确定一个特定的输入点 \(x\),那么我们可以将函数 \(\partial f\) 视为首先获取一个输入点,然后返回该输入点处的雅可比线性映射
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
特别是,我们可以解开事情,以便给定输入点 \(x \in \mathbb{R}^n\) 和切向量 \(v \in \mathbb{R}^n\),我们得到 \(\mathbb{R}^m\) 中的输出切向量。我们将从 \((x, v)\) 对到输出切向量的映射称为雅可比-向量积,并将其写为
\(\qquad (x, v) \mapsto \partial f(x) v\)
JAX 代码中的 JVPs#
回到 Python 代码中,JAX 的 jvp
函数模拟了此转换。给定一个评估 \(f\) 的 Python 函数,JAX 的 jvp
是一种获取 Python 函数来评估 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的方法。
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
用 类似 Haskell 的类型签名,我们可以写成
jvp :: (a -> b) -> a -> T a -> (b, T b)
其中我们使用 T a
来表示 a
的切空间的类型。换句话说,jvp
接受类型为 a -> b
的函数、类型为 a
的值和类型为 T a
的切向量值作为参数。它返回一个由类型为 b
的值和类型为 T b
的输出切向量组成的对。
jvp
转换的函数与原始函数的评估方式非常相似,但它与类型为 a
的每个原始值配对,它会沿着类型为 T a
的切线值推送。对于原始函数将应用的每个原始数值运算,jvp
转换的函数都会执行该原始函数的“JVP 规则”,该规则既评估原始值上的原始函数,又在这些原始值上应用原始函数的 JVP。
该评估策略对计算复杂度有一些直接的影响:由于我们在进行过程中评估 JVP,因此我们不需要为以后存储任何内容,因此内存成本与计算深度无关。此外,jvp
转换的函数的 FLOP 成本约为仅评估函数成本的 3 倍(一个工作单元用于评估原始函数,例如 sin(x)
;一个单元用于线性化,如 cos(x)
;一个单元用于将线性化函数应用于向量,如 cos_x * v
)。换句话说,对于固定的原始点 \(x\),我们可以评估 \(v \mapsto \partial f(x) \cdot v\),其边际成本与评估 \(f\) 的成本大致相同。
这种内存复杂度听起来非常引人注目!那么为什么我们在机器学习中不经常看到正向模式呢?
为了回答这个问题,首先考虑一下如何使用 JVP 构建完整的雅可比矩阵。如果我们对单热切向量应用 JVP,它会显示雅可比矩阵的一列,对应于我们输入的非零条目。因此,我们可以一次构建一列完整的雅可比矩阵,而获得每一列的成本与一次函数评估的成本大致相同。这对于具有“高”雅可比矩阵的函数是有效的,但对于“宽”雅可比矩阵则效率低下。
如果您正在机器学习中进行基于梯度的优化,您可能希望最小化一个损失函数,该函数的参数从 \(\mathbb{R}^n\) 中的参数映射到 \(\mathbb{R}\) 中的标量损失值。这意味着此函数的雅可比矩阵是一个非常宽的矩阵:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们通常将其等同于梯度向量 \(\nabla f(x) \in \mathbb{R}^n\)。逐列构建该矩阵,每次调用花费的 FLOPs 数量与评估原始函数相似,这似乎效率低下!特别是对于训练神经网络,其中 \(f\) 是训练损失函数,而 \(n\) 可能达到数百万或数十亿,这种方法根本无法扩展。
为了更好地处理像这样的函数,我们只需要使用反向模式。
向量-雅可比积(VJPs,又名反向模式自动微分)#
正向模式为我们提供了一个用于评估雅可比-向量积的函数,然后我们可以使用它来逐列构建雅可比矩阵;而反向模式是一种获取用于评估向量-雅可比积(等效于雅可比-转置-向量积)的函数的方法,我们可以使用它来逐行构建雅可比矩阵。
VJPs 的数学表示#
让我们再次考虑一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。从我们对 JVPs 的表示法开始,VJPs 的表示法非常简单
\(\qquad (x, v) \mapsto v \partial f(x)\),
其中 \(v\) 是 \(f\) 在 \(x\) 处的余切空间的一个元素(同构于 \(\mathbb{R}^m\) 的另一个副本)。当严格来说,我们应该将 \(v\) 视为线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),当我们写 \(v \partial f(x)\) 时,我们指的是函数组合 \(v \circ \partial f(x)\),其中类型匹配,因为 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常见情况下,我们可以将 \(v\) 与 \(\mathbb{R}^m\) 中的向量等同起来,并且几乎可以互换使用这两个概念,就像我们有时可能会在“列向量”和“行向量”之间切换而无需过多解释一样。
通过这种等同,我们可以将 VJP 的线性部分另类地视为 JVP 的线性部分的转置(或伴随共轭)
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
对于给定的点 \(x\),我们可以将签名写为
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
余切空间上的对应映射通常被称为 拉回 \(f\) 在 \(x\) 处的值。对于我们的目的来说,关键是它从看起来像 \(f\) 的输出的东西变为看起来像 \(f\) 的输入的东西,正如我们可能从转置线性函数中期望的那样。
JAX 代码中的 VJP#
从数学切换回 Python,JAX 函数 vjp
可以接受一个用于评估 \(f\) 的 Python 函数,并返回一个用于评估 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函数。
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
用 类似 Haskell 的类型签名,我们可以写成
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
其中我们使用 CT a
来表示 a
的余切空间的类型。换句话说,vjp
接受一个类型为 a -> b
的函数和一个类型为 a
的点作为参数,并返回一个由类型为 b
的值和类型为 CT b -> CT a
的线性映射组成的对。
这很棒,因为它允许我们逐行构建雅可比矩阵,并且评估 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOPs 成本仅约为评估 \(f\) 的成本的三倍。特别是,如果我们想要函数 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我们只需调用一次即可完成。这就是为什么 grad
对于基于梯度的优化是高效的,即使对于数百万或数十亿参数的神经网络训练损失函数等目标也是如此。
但是,这也有成本:虽然 FLOPs 很友好,但内存会随着计算深度而扩展。此外,实现传统上比正向模式更复杂,尽管 JAX 有一些技巧(那是未来笔记本的故事!)。
有关反向模式如何工作的更多信息,请参阅 2017 年深度学习夏季学院的这个教程视频。
使用 VJP 的向量值梯度#
如果您对获取向量值梯度(如 tf.gradients
)感兴趣
from jax import vjp
def vgrad(f, x):
y, vjp_fn = vjp(f, x)
return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
[6. 6.]]
使用正向和反向模式的 Hessian-向量积#
在之前的章节中,我们仅使用反向模式实现了 Hessian-向量积函数(假设连续二阶导数)
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这很有效,但我们可以做得更好,并通过将正向模式与反向模式结合使用来节省一些内存。
在数学上,给定一个要微分的函数 \(f : \mathbb{R}^n \to \mathbb{R}\),一个用于线性化函数的点 \(x \in \mathbb{R}^n\),以及一个向量 \(v \in \mathbb{R}^n\),我们想要的 Hessian-向量积函数是
\((x, v) \mapsto \partial^2 f(x) v\)
考虑辅助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。我们只需要它的 JVP,因为它会给我们
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
我们可以几乎直接将其转换为代码
from jax import jvp, grad
# forward-over-reverse
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
更好的是,由于我们不必直接调用 jnp.dot
,因此这个 hvp
函数可以处理任何形状的数组和任意容器类型(例如存储为嵌套列表/字典/元组的向量),甚至不依赖于 jax.numpy
。
这是一个如何使用它的示例
def f(X):
return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
您可能考虑的另一种编写方式是使用反向-覆盖-正向
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jvp(f, primals, tangents)[1]
return grad(g)(primals)
但这不太好,因为正向模式的开销比反向模式小,并且由于外部微分运算符必须微分比内部微分运算符更大的计算,因此将正向模式放在外部效果最佳
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
x, = primals
v, = tangents
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
2.11 ms ± 62.6 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
The slowest run took 6.36 times longer than the fastest. This could mean that an intermediate result is being cached.
7.21 ms ± 6.53 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
The slowest run took 6.00 times longer than the fastest. This could mean that an intermediate result is being cached.
10 ms ± 8.83 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
29.4 ms ± 1.28 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
组合 VJP、JVP 和 vmap
#
雅可比矩阵-矩阵和矩阵-雅可比矩阵积#
现在我们有了 jvp
和 vjp
变换,它们为我们提供了每次推送或拉回单个向量的函数,我们可以使用 JAX 的 vmap
变换一次推送和拉回整个基。特别是,我们可以使用它来编写快速的矩阵-雅可比矩阵和雅可比矩阵-矩阵积。
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
return jnp.vstack([vjp_fun(mi) for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
outs, = vmap(vjp_fun)(M)
return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
52.7 ms ± 381 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
2.6 ms ± 43.5 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_1325/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
# jvp immediately returns the primal and tangent values as a tuple,
# so we'll compute and select the tangents in a list comprehension
return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
_jvp = lambda s: jvp(f, (W,), (s,))[1]
return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
110 ms ± 349 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
1.53 ms ± 30.4 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
jacfwd
和 jacrev
的实现#
现在我们已经看到了快速的雅可比矩阵-矩阵和矩阵-雅可比矩阵积,不难猜测如何编写 jacfwd
和 jacrev
。我们只是使用相同的技术一次推送或拉回整个标准基(同构于单位矩阵)。
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
def jacfun(x):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
有趣的是,Autograd 做不到这一点。我们在 Autograd 中反向模式 jacobian
的 实现 必须使用外部循环 map
每次拉回一个向量。一次推送一个向量通过计算的效率远低于使用 vmap
将所有向量批量处理。
Autograd 做不到的另一件事是 jit
。有趣的是,无论您在要微分的函数中使用多少 Python 动态性,我们始终可以在计算的线性部分使用 jit
。例如
def f(x):
try:
if x < 3:
return 2 * x ** 3
else:
raise ValueError
except ValueError:
return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
复数和微分#
JAX 在复数和微分方面非常出色。为了同时支持 全纯和非全纯微分,它有助于从 JVP 和 VJP 的角度思考。
考虑一个复数到复数的函数 \(f: \mathbb{C} \to \mathbb{C}\),并将其等同于相应的函数 \(g: \mathbb{R}^2 \to \mathbb{R}^2\),
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def g(x, y):
return (u(x, y), v(x, y))
也就是说,我们分解 \(f(z) = u(x, y) + v(x, y) i\),其中 \(z = x + y i\),并将 \(\mathbb{C}\) 与 \(\mathbb{R}^2\) 等同以获得 \(g\)。
由于 \(g\) 仅涉及实数输入和输出,我们已经知道如何为其编写雅可比矩阵-向量积,例如给定一个切向量 \((c, d) \in \mathbb{R}^2\),即
\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
为了获得应用于切向量 \(c + di \in \mathbb{C}\) 的原始函数 \(f\) 的 JVP,我们只需使用相同的定义并将结果识别为另一个复数,
\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数的 JVP 的定义!请注意,\(f\) 是否全纯无关紧要:JVP 是明确的。
这是一个检查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# tangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_dot = c + d * 1j
# check jvp
_, ans = jvp(fun, (z,), (z_dot,))
expected = (grad(u, 0)(x, y) * c +
grad(u, 1)(x, y) * d +
grad(v, 0)(x, y) * c * 1j+
grad(v, 1)(x, y) * d * 1j)
print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True
那么 VJP 呢?我们做一些非常相似的事情:对于余切向量 \(c + di \in \mathbb{C}\),我们将 \(f\) 的 VJP 定义为
\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).
负号是怎么回事?它们只是为了处理复共轭,以及我们正在使用协向量的事实。
这是一个 VJP 规则的检查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# cotangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_bar = jnp.array(c + d * 1j) # for dtype control
# check vjp
_, fun_vjp = vjp(fun, z)
ans, = fun_vjp(z_bar)
expected = (grad(u, 0)(x, y) * c +
grad(v, 0)(x, y) * (-d) +
grad(u, 1)(x, y) * c * (-1j) +
grad(v, 1)(x, y) * (-d) * (-1j))
assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
那么像 grad
、jacfwd
和 jacrev
这样的便捷包装器呢?
对于 \(\mathbb{R} \to \mathbb{R}\) 函数,回想一下我们将 grad(f)(x)
定义为 vjp(f, x)[1](1.0)
,这之所以有效,是因为将 VJP 应用于 1.0
值会揭示梯度(即雅可比矩阵或导数)。对于 \(\mathbb{C} \to \mathbb{R}\) 函数,我们可以做同样的事情:我们仍然可以使用 1.0
作为余切向量,我们只会得到一个复数结果,总结了完整的雅可比矩阵
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return x**2 + y**2
z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
对于一般的 \(\mathbb{C} \to \mathbb{C}\) 函数,雅可比矩阵具有 4 个实数值自由度(如上面的 2x2 雅可比矩阵所示),因此我们不能希望在一个复数中表示所有这些自由度。但是对于全纯函数,我们可以!全纯函数恰好是一个 \(\mathbb{C} \to \mathbb{C}\) 函数,其导数可以表示为单个复数。(柯西-黎曼方程 确保上述 2x2 雅可比矩阵具有复平面中缩放和旋转矩阵的特殊形式,即单个复数在乘法下的作用。)我们可以通过使用协向量 1.0
对 vjp
进行一次调用来揭示该复数。
因为这仅适用于全纯函数,所以要使用此技巧,我们需要向 JAX 保证我们的函数是全纯的;否则,当 grad
用于复数输出函数时,JAX 会引发错误
def f(z):
return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)
所有 holomorphic=True
承诺所做的只是在输出为复数值时禁用错误。当函数不是全纯函数时,我们仍然可以编写 holomorphic=True
,但是我们得到的答案不会表示完整的雅可比矩阵。相反,它将是我们只丢弃输出的虚部的函数的雅可比矩阵
def f(z):
return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)
关于 grad
在这里的工作方式,有一些有用的结果
我们可以对全纯 \(\mathbb{C} \to \mathbb{C}\) 函数使用
grad
。我们可以通过在
grad(f)(x)
的共轭方向上采取步骤,使用grad
来优化 \(f : \mathbb{C} \to \mathbb{R}\) 函数,例如复数参数x
的实值损失函数。如果我们有一个 \(\mathbb{R} \to \mathbb{R}\) 函数,它恰好在内部使用了一些复数值运算(其中一些运算必须是非全纯的,例如卷积中使用的 FFT),那么
grad
仍然有效,并且我们得到的结果与仅使用实数值的实现所给出的结果相同。
在任何情况下,JVP 和 VJP 始终是明确的。如果我们想计算非全纯 \(\mathbb{C} \to \mathbb{C}\) 函数的完整雅可比矩阵,我们可以使用 JVP 或 VJP 来完成!
您应该期望复数在 JAX 中的任何地方都有效。这是通过复矩阵的 Cholesky 分解进行微分
A = jnp.array([[5., 2.+3j, 5j],
[2.-3j, 7., 1.+7j],
[-5j, 1.-7j, 12.]])
def f(X):
L = jnp.linalg.cholesky(X)
return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A)
Array([[-0.7534186 +0.j , -3.0509028 -10.940544j ,
5.9896846 +3.5423026j],
[-3.0509028 +10.940544j , -8.904491 +0.j ,
-5.1351523 -6.559373j ],
[ 5.9896846 -3.5423026j, -5.1351523 +6.559373j ,
0.01320427 +0.j ]], dtype=complex64)
更高级的自动微分#
在本笔记本中,我们研究了 JAX 中自动微分的一些简单,然后逐渐变得更复杂的应用。我们希望您现在觉得在 JAX 中求导既容易又强大。
还有整个世界的其他自动微分技巧和功能。我们没有涵盖的主题,但希望在“高级自动微分食谱”中涵盖的主题包括
高斯-牛顿向量积,线性化一次
自定义 VJP 和 JVP
固定点的有效导数
使用随机 Hessian-向量积估计 Hessian 的迹。
仅使用反向模式自动微分的正向模式自动微分。
对自定义数据类型求导。
检查点(用于高效反向模式的二项式检查点,而不是模型快照)。
使用雅可比矩阵预累积优化 VJP。