The Autodiff Cookbook#
JAX 拥有一个相当通用的自动微分系统。在本 notebook 中,我们将通过一系列有趣自动微分的思路,从基础开始,供您选择使用。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0)
梯度#
从 grad 开始#
您可以使用 grad 对函数进行微分
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
grad 接受一个函数并返回一个函数。如果您有一个 Python 函数 f 用于计算数学函数 \(f\),那么 grad(f) 就是一个 Python 函数,用于计算数学函数 \(\nabla f\)。这意味着 grad(f)(x) 代表了值 \(\nabla f(x)\)。
由于 grad 操作的是函数,您可以将其应用于自身的输出来进行任意多次微分
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
让我们通过一个线性逻辑回归模型来计算使用 grad 计算梯度。首先,设置部分
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
使用带有 argnums 参数的 grad 函数,以相对于位置参数进行微分。
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.433146 -0.7354605 -1.2598922]
W_grad [-0.433146 -0.7354605 -1.2598922]
b_grad -0.69001776
W_grad [-0.433146 -0.7354605 -1.2598922]
b_grad -0.69001776
这个 grad API 直接对应于 Spivak 经典著作《流形上的微积分》(1965) 中出色的表示法,该表示法也用于 Sussman 和 Wisdom 的《经典力学的结构与解释》(2015) 以及他们的《函数微分几何》(2013)。这两本书都提供了开放获取。特别参见《函数微分几何》的“序言”部分,其中对这种表示法进行了辩护。
基本上,当使用 argnums 参数时,如果 f 是一个用于计算数学函数 \(f\) 的 Python 函数,那么 Python 表达式 grad(f, i) 的计算结果是一个用于计算 \(\partial_i f\) 的 Python 函数。
关于嵌套列表、元组和字典进行微分#
对标准 Python 容器进行微分是有效的,因此可以按您喜欢的方式使用元组、列表和字典(以及任意嵌套)。
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32), 'b': Array(-0.69001776, dtype=float32)}
您可以注册您自己的容器类型,以便不仅能与 grad 一起使用,还能与所有 JAX 变换(jit、vmap 等)一起使用。
使用 value_and_grad 评估函数及其梯度#
另一个方便的函数是 value_and_grad,用于高效地计算函数值及其梯度值
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 2.9729187
loss value 2.9729187
与数值差进行核对#
关于导数的一个很棒的地方是,使用有限差值可以轻松地进行核对
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.6890297
b_grad_autodiff -0.69001776
W_dirderiv_numerical 1.3041496
W_dirderiv_autodiff 1.3006744
JAX 提供了一个简单的便利函数,其功能基本相同,但可以核对任意阶数的微分
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
使用 grad-of-grad 进行 Hessian-向量乘积#
通过高阶 grad 我们可以构建一个 Hessian-向量乘积函数。(稍后我们将编写一个更有效的实现,该实现混合了前向模式和反向模式,但这个版本将仅使用反向模式。)
Hessian-向量乘积函数在最小化光滑凸函数的截断牛顿共轭梯度算法中很有用,或者用于研究神经网络训练目标的曲率(例如,1、2、3、4)。
对于具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\)(因此 Hessian 矩阵是对称的),在点 \(x \in \mathbb{R}^n\) 处的 Hessian 表示为 \(\partial^2 f(x)\)。然后,Hessian-向量乘积函数能够计算
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
对于任何 \(v \in \mathbb{R}^n\)。
关键在于不实例化完整的 Hessian 矩阵:如果 \(n\) 很大,可能达到数百万甚至数十亿(在神经网络的上下文中),那么存储它可能是不可行的。
幸运的是,grad 已经为我们提供了一种编写高效 Hessian-向量乘积函数的方法。我们只需要使用以下恒等式
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
其中 \(g(x) = \partial f(x) \cdot v\) 是一个新构造的标量值函数,它将 \(f\) 在 \(x\) 处的梯度与向量 \(v\) 进行点积。请注意,我们只对向量值自变量的标量值函数进行微分,这正是我们知道 grad 高效之处。
在 JAX 代码中,我们可以这样写
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这个例子表明,您可以自由使用词法闭包,JAX 永远不会受到干扰或混淆。
我们将在下面的几个单元格中检查此实现,一旦我们了解了如何计算密集 Hessian 矩阵。我们还将编写一个更优化的版本,该版本同时使用前向模式和反向模式。
使用 jacfwd 和 jacrev 计算 Jacobian 和 Hessian#
您可以使用 jacfwd 和 jacrev 函数计算完整的 Jacobian 矩阵
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05069415 0.1091874 0.07506633]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447992]
[ 0.00574409 -0.0193281 0.01078958]]
jacrev result, with shape (4, 3)
[[ 0.05069415 0.10918741 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
这两个函数计算相同的值(在机器精度范围内),但在实现上有所不同:jacfwd 使用前向模式自动微分,这对于“高瘦”的 Jacobian 矩阵(输出多于输入)更有效;而 jacrev 使用反向模式,这对于“宽胖”的 Jacobian 矩阵(输入多于输出)更有效。对于接近方形的矩阵,jacfwd 可能比 jacrev 更有优势。
您也可以将 jacfwd 和 jacrev 与容器类型一起使用
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05069415 0.10918741 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
Jacobian from b to logits is
[0.09748876 0.16102302 0.24190766 0.00776229]
有关前向模式和反向模式的更多详细信息,以及如何尽可能高效地实现 jacfwd 和 jacrev,请继续阅读!
通过组合使用这两个函数,我们可以计算密集 Hessian 矩阵
def hessian(f):
return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02058932 0.04434624 0.03048803]
[ 0.04434623 0.09551499 0.06566654]
[ 0.03048803 0.06566655 0.04514575]]
[[-0.0743913 0.09129842 -0.01268033]
[ 0.09129842 -0.11204806 0.01556223]
[-0.01268034 0.01556223 -0.00216142]]
[[ 0.01176856 0.00135791 -0.02942139]
[ 0.00135791 0.00015668 -0.00339478]
[-0.0294214 -0.00339478 0.07355348]]
[[-0.00418412 0.014079 -0.00785936]
[ 0.014079 -0.04737393 0.02644569]
[-0.00785936 0.02644569 -0.01476286]]]
这个形状是有意义的:如果我们从一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,我们期望得到以下形状
\(f(x) \in \mathbb{R}^m\),\(f\) 在 \(x\) 处的值;
\(\partial f(x) \in \mathbb{R}^{m \times n}\),在 \(x\) 处的 Jacobian 矩阵;
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),在 \(x\) 处的 Hessian;
依此类推。
为了实现 hessian,我们可以使用 jacfwd(jacrev(f)) 或 jacrev(jacfwd(f)) 或它们的任何其他组合。但前向模式套反向模式通常是最有效的。这是因为在内部 Jacobian 计算中,我们经常对具有宽 Jacobian 的函数进行微分(可能像损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\)),而在外部 Jacobian 计算中,我们对具有方形 Jacobian 的函数进行微分(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),这时前向模式占优势。
它是如何工作的:两个基础的自动微分函数#
Jacobian-向量乘积(JVP,也称为前向模式自动微分)#
JAX 包含前向模式和反向模式自动微分的高效通用实现。grad 函数建立在反向模式的基础上,但要解释这两种模式的区别以及何时使用它们,我们需要一些数学背景。
数学中的 JVP#
在数学上,给定一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\),在输入点 \(x \in \mathbb{R}^n\) 处计算的 \(f\) 的 Jacobian,记作 \(\partial f(x)\),通常被认为是 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的一个矩阵
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
但我们也可以将 \(\partial f(x)\) 视为一个线性映射,它将 \(f\) 在点 \(x\) 处的切空间(它就是另一个 \(\mathbb{R}^n\) 的副本)映射到 \(f\) 的共域在点 \(f(x)\) 处的切空间(\(\mathbb{R}^m\) 的副本)
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
这个映射被称为 \(f\) 在 \(x\) 处的推前映射。Jacobian 矩阵就是标准基下这个线性映射的矩阵。
如果我们不固定一个特定的输入点 \(x\),那么我们可以将函数 \(\partial f\) 视为先接受一个输入点,然后返回该输入点处的 Jacobian 线性映射
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
特别地,我们可以取消柯里化,这样给定输入点 \(x \in \mathbb{R}^n\) 和切向量 \(v \in \mathbb{R}^n\),我们可以得到一个输出切向量 \(\mathbb{R}^m\)。我们将从 \((x, v)\) 对到输出切向量的映射称为Jacobian-向量乘积,并将其写为
\(\qquad (x, v) \mapsto \partial f(x) v\)
JAX 代码中的 JVP#
回到 Python 代码,JAX 的 jvp 函数模拟了这种变换。给定一个用于计算 \(f\) 的 Python 函数,JAX 的 jvp 可以得到一个用于计算 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的 Python 函数。
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
用类似 Haskell 的类型签名来说,我们可以写成
jvp :: (a -> b) -> a -> T a -> (b, T b)
其中我们使用 T a 来表示 a 的切空间类型。通俗地说,jvp 接受的参数是一个类型为 a -> b 的函数,一个类型为 a 的值,以及一个类型为 T a 的切向量值。它返回一个对,其中包含一个类型为 b 的值和一个类型为 T b 的输出切向量。
jvp 变换后的函数与原始函数类似地计算,但与每个类型为 a 的原始值配对,它沿着类型为 T a 的切向量值进行传播。对于原始函数将应用的每个基本数值运算,jvp 变换后的函数会执行该基本运算的“JVP 规则”,该规则同时在原始值上计算基本运算,并应用基本运算的 JVP。
这种计算策略对计算复杂度有一些直接的影响:由于我们边计算边评估 JVP,因此无需存储任何东西以供将来使用,内存成本与计算深度无关。此外,jvp 变换后函数的 FLOP 成本大约是仅评估函数成本的 3 倍(评估原始函数的工作量为 1,例如 sin(x);线性化的工作量为 1,例如 cos(x);以及将线性化函数应用于向量的工作量为 1,例如 cos_x * v)。换句话说,对于固定的原始点 \(x\),我们可以以与评估 \(f\) 相似的边际成本来计算 \(v \mapsto \partial f(x) \cdot v\)。
这个内存复杂度听起来相当诱人!那么为什么我们在机器学习中很少看到前向模式呢?
要回答这个问题,首先考虑如何使用 JVP 来构建完整的 Jacobian 矩阵。如果我们对一个单热(one-hot)切向量应用 JVP,它会揭示 Jacobian 矩阵的一列,对应于我们输入的非零条目。因此,我们可以逐列构建完整的 Jacobian,并且获取每一列的成本与一次函数求值相似。这对于具有“高瘦”Jacobian 的函数会很有效,但对于“宽胖”Jacobian 的函数则效率低下。
如果您正在进行机器学习中的基于梯度的优化,您可能希望最小化一个从参数 \(\mathbb{R}^n\) 到标量损失值 \(\mathbb{R}\) 的损失函数。这意味着该函数的 Jacobian 是一个非常宽的矩阵:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们经常将其标识为梯度向量 \(\nabla f(x) \in \mathbb{R}^n\)。逐列构建该矩阵,每次调用都需要与求原始函数相似的 FLOPs,这看起来效率低下!特别是在训练神经网络时,其中 \(f\) 是一个训练损失函数,而 \(n\) 可能达到数百万甚至数十亿,这种方法根本无法扩展。
为了更好地处理这类函数,我们只需要使用反向模式。
向量-Jacobian 乘积(VJP,也称为反向模式自动微分)#
前向模式返回一个用于计算 Jacobian-向量乘积的函数,然后我们可以用它来逐列构建 Jacobian 矩阵;而反向模式则是一种返回用于计算向量-Jacobian 乘积(等效于 Jacobian-transpose-向量乘积)的函数的方法,然后我们可以用它来逐行构建 Jacobian 矩阵。
数学中的 VJP#
我们再次考虑一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。从 JVP 的表示法开始,VJP 的表示法非常简单
\(\qquad (x, v) \mapsto v \partial f(x)\),
其中 \(v\) 是 \(f\) 在 \(x\) 处的余切空间(同构于 \(\mathbb{R}^m\) 的另一个副本)的一个元素。在严谨的场合,我们应该将 \(v\) 视为一个线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),当我们写 \(v \partial f(x)\) 时,我们指的是函数复合 \(v \circ \partial f(x)\),其中类型匹配是因为 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常见情况下,我们可以将 \(v\) 与 \(\mathbb{R}^m\) 中的向量进行同一化,并近乎可互换地使用它们,就像我们有时会不加评论地在“列向量”和“行向量”之间切换一样。
有了这个同一化,我们也可以将 VJP 的线性部分视为 JVP 的线性部分的转置(或伴随共轭)
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
对于给定的点 \(x\),我们可以写出签名
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
对应于余切空间的映射通常被称为 \(f\) 在 \(x\) 处的拉回映射。对我们而言关键在于,它从看起来像 \(f\) 的输出的东西映射到看起来像 \(f\) 的输入的东西,正如我们可能从转置的线性函数中期望的那样。
JAX 代码中的 VJP#
从数学切换回 Python,JAX 函数 vjp 可以接受一个用于计算 \(f\) 的 Python 函数,并返回一个用于计算 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函数。
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
用类似 Haskell 的类型签名来说,我们可以写成
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
其中我们使用 CT a 来表示 a 的余切空间类型。通俗地说,vjp 接受的参数是一个类型为 a -> b 的函数和一个类型为 a 的点,并返回一个对,其中包含一个类型为 b 的值和一个类型为 CT b -> CT a 的线性映射。
这很好,因为它允许我们逐行构建 Jacobian 矩阵,并且计算 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOPs 成本仅为计算 \(f\) 的成本的约三倍。特别是,如果我们想要一个函数 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我们可以在一次调用中完成。这就是 grad 对于基于梯度的优化来说效率高的原因,即使对于像神经网络训练损失函数那样具有数百万或数十亿参数的目标也是如此。
但有一个代价:虽然 FLOPs 对我们有利,但内存会随着计算深度的增加而扩展。此外,它的实现通常比前向模式更复杂,尽管 JAX 有一些绝招(那是以后 notebook 的故事!)。
有关反向模式工作原理的更多信息,请参阅 2017 年深度学习暑期学校的这个教程视频。
使用 VJP 进行向量值梯度#
如果您对计算向量值梯度(例如 tf.gradients)感兴趣
from jax import vjp
def vgrad(f, x):
y, vjp_fn = vjp(f, x)
return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
[6. 6.]]
使用前向模式和反向模式进行 Hessian-向量乘积#
在前面的部分,我们仅使用反向模式实现了一个 Hessian-向量乘积函数(假设二阶导数连续)
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这很高效,但我们可以做得更好,通过同时使用前向模式和反向模式来节省一些内存。
在数学上,给定一个需要微分的函数 \(f : \mathbb{R}^n \to \mathbb{R}\),一个需要线性化的点 \(x \in \mathbb{R}^n\),以及一个向量 \(v \in \mathbb{R}^n\),我们想要实现的 Hessian-向量乘积函数是
\((x, v) \mapsto \partial^2 f(x) v\)
考虑辅助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),它被定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。我们只需要它的 JVP,因为它将给出
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
我们可以几乎直接地将其转换为代码
from jax import jvp, grad
# forward-over-reverse
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
更妙的是,由于我们不必直接调用 jnp.dot,这个 hvp 函数可以处理任意形状的数组和任意容器类型(如存储在嵌套列表/字典/元组中的向量),甚至不依赖于 jax.numpy。
以下是使用它的示例
def f(X):
return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
您可能考虑的另一种写法是使用反向模式套前向模式
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jvp(f, primals, tangents)[1]
return grad(g)(primals)
这不算太好,因为前向模式的开销比反向模式小,而且由于这里的外部微分算子必须微分比内部算子更大的计算量,所以将前向模式放在外面效果最好
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
x, = primals
v, = tangents
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
3.07 ms ± 123 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
The slowest run took 7.44 times longer than the fastest. This could mean that an intermediate result is being cached.
14.6 ms ± 14 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
The slowest run took 4.53 times longer than the fastest. This could mean that an intermediate result is being cached.
16.7 ms ± 12.7 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
44.3 ms ± 2.08 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
组合 VJP、JVP 和 vmap#
Jacobian-矩阵和矩阵-Jacobian 乘积#
既然我们有了 jvp 和 vjp 变换,它们分别提供了每次推送单个向量或拉回单个向量的函数,我们可以使用 JAX 的 vmap 变换一次性推送和拉回整个基。特别是,我们可以用它来快速编写矩阵-Jacobian 和 Jacobian-矩阵乘积。
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
return jnp.vstack([vjp_fun(mi) for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
outs, = vmap(vjp_fun)(M)
return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
46.3 ms ± 132 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
3.16 ms ± 48.1 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_2012/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
# jvp immediately returns the primal and tangent values as a tuple,
# so we'll compute and select the tangents in a list comprehension
return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
_jvp = lambda s: jvp(f, (W,), (s,))[1]
return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
88.7 ms ± 114 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
1.39 ms ± 46.3 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
jacfwd 和 jacrev 的实现#
既然我们已经看到了快速的 Jacobian-矩阵和矩阵-Jacobian 乘积,那么推测如何编写 jacfwd 和 jacrev 并不难。我们只需要使用相同的技术一次性地推送或拉回整个标准基(同构于单位矩阵)。
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
def jacfun(x):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
有趣的是,Autograd 无法做到这一点。我们在 Autograd 中实现的反向模式 jacobian 必须通过外部循环 map 一次拉回一个向量。逐个向量地通过计算进行推送比使用 vmap 进行批量处理效率低得多。
Autograd 无法做的另一件事是 jit。有趣的是,无论您的被微分函数中使用多少 Python 动态性,我们总可以使用 jit 来处理计算的线性部分。例如
def f(x):
try:
if x < 3:
return 2 * x ** 3
else:
raise ValueError
except ValueError:
return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
复数与微分#
JAX 在复数和微分方面表现出色。为了支持全纯和非全纯微分,最好从 JVP 和 VJP 的角度来思考。
考虑一个复数到复数的函数 \(f: \mathbb{C} \to \mathbb{C}\),并将其与对应的函数 \(g: \mathbb{R}^2 \to \mathbb{R}^2\) 进行同一化,
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def g(x, y):
return (u(x, y), v(x, y))
也就是说,我们将 \(f(z) = u(x, y) + v(x, y) i\) 分解,其中 \(z = x + y i\),并识别 \(\mathbb{C}\) 与 \(\mathbb{R}^2\) 以获得 \(g\)。
由于 \(g\) 只涉及实数输入和输出,我们已经知道如何为其编写 Jacobian-向量乘积,例如给定一个切向量 \((c, d) \in \mathbb{R}^2\),即
\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
为了获得原始函数 \(f\) 应用于切向量 \(c + di \in \mathbb{C}\) 的 JVP,我们只需使用相同的定义,并将结果识别为另一个复数:
\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数的 JVP 定义!请注意,\(f\) 是否全纯并不重要:JVP 是明确的。
这里有一个检查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# tangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_dot = c + d * 1j
# check jvp
_, ans = jvp(fun, (z,), (z_dot,))
expected = (grad(u, 0)(x, y) * c +
grad(u, 1)(x, y) * d +
grad(v, 0)(x, y) * c * 1j+
grad(v, 1)(x, y) * d * 1j)
print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True
VJP 呢?我们做类似的事情:对于余切向量 \(c + di \in \mathbb{C}\),我们定义 \(f\) 的 VJP 为
\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).
为什么会有负号?它们只是为了处理复共轭以及我们正在处理共向量的事实。
这里有一个 VJP 规则的检查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# cotangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_bar = jnp.array(c + d * 1j) # for dtype control
# check vjp
_, fun_vjp = vjp(fun, z)
ans, = fun_vjp(z_bar)
expected = (grad(u, 0)(x, y) * c +
grad(v, 0)(x, y) * (-d) +
grad(u, 1)(x, y) * c * (-1j) +
grad(v, 1)(x, y) * (-d) * (-1j))
assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
那么像 grad、jacfwd 和 jacrev 这样的方便包装器呢?
对于 \(\mathbb{R} \to \mathbb{R}\) 函数,回想一下我们定义 grad(f)(x) 为 vjp(f, x)[1](1.0),这是有效的,因为将 VJP 应用于 1.0 值会揭示梯度(即 Jacobian,或导数)。我们可以对 \(\mathbb{C} \to \mathbb{R}\) 函数做同样的事情:我们可以仍然使用 1.0 作为共向量,我们得到一个复数结果来总结完整的 Jacobian
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return x**2 + y**2
z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
对于一般的 \(\mathbb{C} \to \mathbb{C}\) 函数,Jacobian 有 4 个实值自由度(如上面的 2x2 Jacobian 矩阵所示),因此我们无法在一个复数中表示所有这些自由度。但对于全纯函数,我们可以!全纯函数本质上是一个 \(\mathbb{C} \to \mathbb{C}\) 函数,它具有特殊的性质,即其导数可以表示为一个复数。(柯西-黎曼方程确保上面的 2x2 Jacobian 具有在复平面中按比例缩放和旋转矩阵的特殊形式,即乘以一个复数的运算。)我们可以通过一次调用 vjp 并使用共向量 1.0 来揭示这个复数。
由于这只适用于全纯函数,因此在使用此技巧时,我们需要向 JAX 承诺我们的函数是全纯的;否则,当 grad 用于复数输出函数时,JAX 将引发错误
def f(z):
return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)
所有 holomorphic=True 的承诺只是在输出为复值时禁用错误。我们仍然可以对非全纯函数使用 holomorphic=True,但我们得到的结果将不代表完整的 Jacobian。相反,它将是函数的 Jacobian,其中我们仅丢弃输出的虚部
def f(z):
return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)
对于 grad 在这里如何工作,有一些有用的推论
我们可以在全纯 \(\mathbb{C} \to \mathbb{C}\) 函数上使用
grad。我们可以使用
grad来优化 \(f : \mathbb{C} \to \mathbb{R}\) 函数,例如复参数x的实值损失函数,通过在grad(f)(x)的共轭方向上进行步进。如果我们有一个 \(\mathbb{R} \to \mathbb{R}\) 函数,它恰好在内部使用了某些复数值操作(其中一些必须是非全纯的,例如卷积中使用的 FFT),那么
grad仍然有效,并且我们得到的结果与仅使用实值进行计算的结果相同。
无论哪种情况,JVP 和 VJP 始终是明确的。如果我们想计算一个非全纯 \(\mathbb{C} \to \mathbb{C}\) 函数的完整 Jacobian 矩阵,我们可以使用 JVP 或 VJP 来做到这一点!
您应该期望复数在 JAX 中无处不在。这里是计算复数矩阵的 Cholesky 分解的微分
A = jnp.array([[5., 2.+3j, 5j],
[2.-3j, 7., 1.+7j],
[-5j, 1.-7j, 12.]])
def f(X):
L = jnp.linalg.cholesky(X)
return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A)
Array([[-0.7534186 +0.j , -3.0509028 -10.940544j ,
5.9896846 +3.5423026j],
[-3.0509028 +10.940544j , -8.904491 +0.j ,
-5.1351523 -6.559373j ],
[ 5.9896846 -3.5423026j, -5.1351523 +6.559373j ,
0.01320427 +0.j ]], dtype=complex64)
更高级的自动微分#
在本 notebook 中,我们通过一些简单的,然后是越来越复杂的 JAX 自动微分应用。我们希望您现在觉得在 JAX 中求导既简单又强大。
还有许多其他自动微分技巧和功能。我们没有涵盖但希望在“高级自动微分 Cookbook”中涵盖的主题包括
Gauss-Newton 向量乘积,一次线性化
自定义 VJP 和 JVP
固定点处的有效导数
使用随机 Hessian-向量乘积估计 Hessian 的迹。
仅使用反向模式自动微分的前向模式自动微分。
对自定义数据类型进行微分。
检查点(用于高效反向模式的二项式检查点,而不是模型快照)。
通过 Jacobian 预累加优化 VJP。