高阶导数#
求导(第 2 部分)#
JAX 的自动微分使得计算高阶导数变得简单,因为计算导数的函数本身也是可微的。因此,高阶导数就像堆叠变换一样简单。
单变量情况已在 自动微分 教程中介绍,其中的示例展示了如何使用 jax.grad() 来计算 \(f(x) = x^3 + 2x^2 - 3x + 1\) 的导数。
在多变量情况下,高阶导数更为复杂。函数的二阶导数由其 Hessian 矩阵表示,其定义为:
多元实值函数 \(f: \mathbb R^n\to\mathbb R\) 的 Hessian 矩阵可以看作是其梯度的 雅可比矩阵 (Jacobian)。
JAX 提供了两种用于计算函数雅可比矩阵的变换:jax.jacfwd() 和 jax.jacrev(),分别对应前向模式和反向模式自动微分。它们给出相同的结果,但在不同情况下,其中一种可能比另一种更高效——请参阅 关于自动微分的视频。
import jax
def hessian(f):
return jax.jacfwd(jax.grad(f))
让我们在点积 \(f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}\) 上验证这一点是否正确。
如果 \(i=j\),则 \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2\)。否则,\(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0\)。
import jax.numpy as jnp
def f(x):
return jnp.dot(x, x)
hessian(f)(jnp.array([1., 2., 3.]))
Array([[2., 0., 0.],
[0., 2., 0.],
[0., 0., 2.]], dtype=float32)
高阶导数的应用#
一些元学习技术,例如模型无关元学习 (MAML),需要对梯度更新进行微分。在其他框架中这可能非常繁琐,但在 JAX 中却容易得多。
def meta_loss_fn(params, data):
"""Computes the loss after one step of SGD."""
grads = jax.grad(loss_fn)(params, data)
return loss_fn(params - lr * grads, data)
meta_grads = jax.grad(meta_loss_fn)(params, data)
停止梯度#
自动微分支持自动计算函数相对于其输入的梯度。然而,有时你可能需要额外的控制:例如,你可能希望避免通过计算图的某些子集进行反向传播。
以 TD(0) (时间差分) 强化学习更新为例。它用于根据与环境交互的经验来学习估计环境中状态的价值。假设状态 \(s_{t-1}\) 中的价值估计 \(v_{\theta}(s_{t-1})\) 由线性函数参数化。
# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
考虑从状态 \(s_{t-1}\) 到状态 \(s_t\) 的转换,在此期间观察到了奖励 \(r_t\)。
# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])
网络参数的 TD(0) 更新为:
此更新不是任何损失函数的梯度。
但是,它可以被写成伪损失函数的梯度:
前提是忽略目标 \(r_t + v_{\theta}(s_t)\) 对参数 \(\theta\) 的依赖。
如何在 JAX 中实现这一点?如果你天真地写出伪损失,你会得到:
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return -0.5 * ((target - v_tm1) ** 2)
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
Array([-1.2, 1.2, -1.2], dtype=float32)
但是 td_update 不会计算 TD(0) 更新,因为梯度计算将包含 target 对 \(\theta\) 的依赖。
你可以使用 jax.lax.stop_gradient() 来强制 JAX 忽略目标对 \(\theta\) 的依赖。
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
Array([ 1.2, 2.4, -1.2], dtype=float32)
这将把 target 视为不依赖于参数 \(\theta\),并计算正确的参数更新。
现在,让我们也使用原始的 TD(0) 更新表达式计算 \(\Delta \theta\),以进行交叉验证。你可能希望尝试使用 jax.grad() 和你目前所学的知识自行实现这一点。这是我们的解决方案:
s_grad = jax.grad(value_fn)(theta, s_tm1)
delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad
delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`
Array([ 1.2, 2.4, -1.2], dtype=float32)
jax.lax.stop_gradient 在其他设置中也可能很有用,例如,如果你希望来自某个损失的梯度只影响神经网络参数的子集(例如,因为其他参数是使用不同的损失进行训练的)。
使用 stop_gradient 的直通估计器#
直通估计器是一种用于定义不可微函数“梯度”的技巧。给定一个非微分函数 \(f : \mathbb{R}^n \to \mathbb{R}^n\),将其作为我们希望求梯度的更大函数的一部分使用,我们只需在反向传递期间假定 \(f\) 是恒等函数。这可以使用 jax.lax.stop_gradient 简洁地实现。
def f(x):
return jnp.round(x) # non-differentiable
def straight_through_f(x):
# Create an exactly-zero expression with Sterbenz lemma that has
# an exactly-one gradient.
zero = x - jax.lax.stop_gradient(x)
return zero + jax.lax.stop_gradient(f(x))
print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))
print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
f(x): 3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0
逐样本梯度#
虽然大多数机器学习系统出于计算效率和/或方差减少的原因,从数据批次 (batch) 中计算梯度和更新,但有时需要获取与批次中每个特定样本相关的梯度/更新。
例如,这在根据梯度大小优先处理数据,或在逐样本的基础上应用裁剪/归一化时是必需的。
在许多框架(PyTorch, TF, Theano)中,计算逐样本梯度通常并不容易,因为库直接在批次上累加梯度。天真的变通方法(例如为每个样本计算单独的损失,然后聚合产生的梯度)通常效率极低。
在 JAX 中,你可以以一种简单且高效的方式定义计算逐样本梯度的代码。
只需将 jax.jit(), jax.vmap() 和 jax.grad() 变换结合起来。
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
让我们逐个变换进行说明。
首先,你将 jax.grad() 应用于 td_loss,以获得一个计算单个(非批处理)输入上损失相对于参数的梯度的函数。
dtdloss_dtheta = jax.grad(td_loss)
dtdloss_dtheta(theta, s_tm1, r_t, s_t)
Array([ 1.2, 2.4, -1.2], dtype=float32)
此函数计算上述数组的一行。
然后,你使用 jax.vmap() 对此函数进行向量化。这会给所有输入和输出添加一个批处理维度。现在,给定一个输入批次,你将生成一个输出批次——批次中的每个输出都对应于输入批次中相应成员的梯度。
almost_perex_grads = jax.vmap(dtdloss_dtheta)
batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
这还不是我们想要的,因为我们必须手动为该函数提供一批 theta,而我们实际上想使用单个 theta。我们通过向 jax.vmap() 添加 in_axes 来解决这个问题,将 theta 指定为 None,将其他参数指定为 0。这使得结果函数仅向其他参数添加一个额外的轴,使 theta 保持不批处理状态,正如我们所希望的那样。
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
这实现了我们的需求,但比原本可以达到的速度要慢。现在,将整个内容包装在 jax.jit() 中,以获得该函数的编译且高效的版本。
perex_grads = jax.jit(inefficient_perex_grads)
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
3.01 ms ± 6.13 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.33 μs ± 8.76 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
使用 jax.grad 之 jax.grad 进行 Hessian 向量积#
你可以使用高阶 jax.grad() 做的一件事是构建 Hessian 向量积函数。(稍后你将编写一个更高效的实现,结合前向和反向模式,但这一个将使用纯反向模式。)
Hessian 向量积函数在用于最小化平滑凸函数的 截断牛顿共轭梯度算法 中,或者用于研究神经网络训练目标的曲率时非常有用(例如 1, 2, 3, 4)。
对于具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\)(使得 Hessian 矩阵是对称的),点 \(x \in \mathbb{R}^n\) 处的 Hessian 记为 \(\partial^2 f(x)\)。那么 Hessian 向量积函数能够计算
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
对于任何 \(v \in \mathbb{R}^n\)。
诀窍在于不要实例化完整的 Hessian 矩阵:如果 \(n\) 很大,在神经网络的上下文中可能是数百万或数十亿,那么存储它可能是不可能的。
幸运的是,jax.grad() 已经为我们提供了一种编写高效 Hessian 向量积函数的方法。你只需要利用恒等式:
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(f\) 在 \(x\) 处的梯度与向量 \(v\) 点积。请注意,你只是在对向量参数的标量值函数进行微分,这正是你知道 jax.grad() 高效的地方。
在 JAX 代码中,你可以直接写成:
def hvp(f, x, v):
return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)
这个例子表明,你可以自由地使用词法闭包,而 JAX 永远不会因此感到困惑。
你将在下方几个单元格中检查此实现,一旦你学会了如何计算稠密 Hessian 矩阵。你还将编写一个使用前向模式和反向模式的更好版本。
使用 jax.jacfwd 和 jax.jacrev 计算雅可比矩阵和 Hessian 矩阵#
你可以使用 jax.jacfwd() 和 jax.jacrev() 函数计算完整的雅可比矩阵。
from jax import jacfwd, jacrev
# Define a sigmoid function.
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
# Initialize random model coefficients
key = jax.random.key(0)
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05069415 0.1091874 0.07506633]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447992]
[ 0.00574409 -0.0193281 0.01078958]]
jacrev result, with shape (4, 3)
[[ 0.05069415 0.10918741 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
这两个函数计算相同的值(在机器数值精度范围内),但在实现上有所不同:jax.jacfwd() 使用前向模式自动微分,这对于“高”雅可比矩阵(输出多于输入)更高效,而 jax.jacrev() 使用反向模式,这对于“宽”雅可比矩阵(输入多于输出)更高效。对于近乎方形的矩阵,jax.jacfwd() 可能比 jax.jacrev() 略有优势。
你也可以在容器类型中使用 jax.jacfwd() 和 jax.jacrev()。
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jax.jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05069415 0.10918741 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
Jacobian from b to logits is
[0.09748876 0.16102302 0.24190766 0.00776229]
有关前向和反向模式的更多详细信息,以及如何尽可能高效地实现 jax.jacfwd() 和 jax.jacrev(),请继续阅读!
使用其中两个函数的组合,我们就可以计算稠密 Hessian 矩阵。
def hessian(f):
return jax.jacfwd(jax.jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02058932 0.04434624 0.03048803]
[ 0.04434623 0.09551499 0.06566654]
[ 0.03048803 0.06566655 0.04514575]]
[[-0.0743913 0.09129842 -0.01268033]
[ 0.09129842 -0.11204806 0.01556223]
[-0.01268034 0.01556223 -0.00216142]]
[[ 0.01176856 0.00135791 -0.02942139]
[ 0.00135791 0.00015668 -0.00339478]
[-0.0294214 -0.00339478 0.07355348]]
[[-0.00418412 0.014079 -0.00785936]
[ 0.014079 -0.04737393 0.02644569]
[-0.00785936 0.02644569 -0.01476286]]]
这种形状是有道理的:如果你从函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,你期望得到以下形状:
\(f(x) \in \mathbb{R}^m\),即 \(f\) 在 \(x\) 处的值,
\(\partial f(x) \in \mathbb{R}^{m \times n}\),即 \(x\) 处的雅可比矩阵,
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),即 \(x\) 处的 Hessian 矩阵,
依此类推。
要实现 hessian,你可以使用 jacfwd(jacrev(f)) 或 jacrev(jacfwd(f)) 或这两者的任何其他组合。但前向-反向(forward-over-reverse)通常最有效。这是因为在内部雅可比矩阵计算中,我们通常是对宽雅可比函数(例如损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\))进行微分,而在外部雅可比矩阵计算中,我们是对具有方形雅可比矩阵的函数进行微分(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),这正是前向模式胜出的地方。