高阶导数#

求导(第 2 部分)#

JAX 的自动微分使得计算高阶导数变得简单,因为计算导数的函数本身也是可微的。因此,高阶导数就像堆叠变换一样简单。

单变量情况已在 自动微分 教程中介绍,其中的示例展示了如何使用 jax.grad() 来计算 \(f(x) = x^3 + 2x^2 - 3x + 1\) 的导数。

在多变量情况下,高阶导数更为复杂。函数的二阶导数由其 Hessian 矩阵表示,其定义为:

\[(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.\]

多元实值函数 \(f: \mathbb R^n\to\mathbb R\) 的 Hessian 矩阵可以看作是其梯度的 雅可比矩阵 (Jacobian)

JAX 提供了两种用于计算函数雅可比矩阵的变换:jax.jacfwd()jax.jacrev(),分别对应前向模式和反向模式自动微分。它们给出相同的结果,但在不同情况下,其中一种可能比另一种更高效——请参阅 关于自动微分的视频

import jax

def hessian(f):
  return jax.jacfwd(jax.grad(f))

让我们在点积 \(f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}\) 上验证这一点是否正确。

如果 \(i=j\),则 \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2\)。否则,\(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0\)

import jax.numpy as jnp

def f(x):
  return jnp.dot(x, x)

hessian(f)(jnp.array([1., 2., 3.]))
Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)

高阶导数的应用#

一些元学习技术,例如模型无关元学习 (MAML),需要对梯度更新进行微分。在其他框架中这可能非常繁琐,但在 JAX 中却容易得多。

def meta_loss_fn(params, data):
  """Computes the loss after one step of SGD."""
  grads = jax.grad(loss_fn)(params, data)
  return loss_fn(params - lr * grads, data)

meta_grads = jax.grad(meta_loss_fn)(params, data)

停止梯度#

自动微分支持自动计算函数相对于其输入的梯度。然而,有时你可能需要额外的控制:例如,你可能希望避免通过计算图的某些子集进行反向传播。

以 TD(0) (时间差分) 强化学习更新为例。它用于根据与环境交互的经验来学习估计环境中状态的价值。假设状态 \(s_{t-1}\) 中的价值估计 \(v_{\theta}(s_{t-1})\) 由线性函数参数化。

# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])

考虑从状态 \(s_{t-1}\) 到状态 \(s_t\) 的转换,在此期间观察到了奖励 \(r_t\)

# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

网络参数的 TD(0) 更新为:

\[ \Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1}) \]

此更新不是任何损失函数的梯度。

但是,它可以被写成伪损失函数的梯度:

\[ L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2 \]

前提是忽略目标 \(r_t + v_{\theta}(s_t)\) 对参数 \(\theta\) 的依赖。

如何在 JAX 中实现这一点?如果你天真地写出伪损失,你会得到:

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return -0.5 * ((target - v_tm1) ** 2)

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
Array([-1.2,  1.2, -1.2], dtype=float32)

但是 td_update 不会计算 TD(0) 更新,因为梯度计算将包含 target\(\theta\) 的依赖。

你可以使用 jax.lax.stop_gradient() 来强制 JAX 忽略目标对 \(\theta\) 的依赖。

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)

td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta
Array([ 1.2,  2.4, -1.2], dtype=float32)

这将把 target 视为依赖于参数 \(\theta\),并计算正确的参数更新。

现在,让我们也使用原始的 TD(0) 更新表达式计算 \(\Delta \theta\),以进行交叉验证。你可能希望尝试使用 jax.grad() 和你目前所学的知识自行实现这一点。这是我们的解决方案:

s_grad = jax.grad(value_fn)(theta, s_tm1)
delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad

delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`
Array([ 1.2,  2.4, -1.2], dtype=float32)

jax.lax.stop_gradient 在其他设置中也可能很有用,例如,如果你希望来自某个损失的梯度只影响神经网络参数的子集(例如,因为其他参数是使用不同的损失进行训练的)。

使用 stop_gradient 的直通估计器#

直通估计器是一种用于定义不可微函数“梯度”的技巧。给定一个非微分函数 \(f : \mathbb{R}^n \to \mathbb{R}^n\),将其作为我们希望求梯度的更大函数的一部分使用,我们只需在反向传递期间假定 \(f\) 是恒等函数。这可以使用 jax.lax.stop_gradient 简洁地实现。

def f(x):
  return jnp.round(x)  # non-differentiable

def straight_through_f(x):
  # Create an exactly-zero expression with Sterbenz lemma that has
  # an exactly-one gradient.
  zero = x - jax.lax.stop_gradient(x)
  return zero + jax.lax.stop_gradient(f(x))

print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
f(x):  3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0

逐样本梯度#

虽然大多数机器学习系统出于计算效率和/或方差减少的原因,从数据批次 (batch) 中计算梯度和更新,但有时需要获取与批次中每个特定样本相关的梯度/更新。

例如,这在根据梯度大小优先处理数据,或在逐样本的基础上应用裁剪/归一化时是必需的。

在许多框架(PyTorch, TF, Theano)中,计算逐样本梯度通常并不容易,因为库直接在批次上累加梯度。天真的变通方法(例如为每个样本计算单独的损失,然后聚合产生的梯度)通常效率极低。

在 JAX 中,你可以以一种简单且高效的方式定义计算逐样本梯度的代码。

只需将 jax.jit(), jax.vmap()jax.grad() 变换结合起来。

perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2,  2.4, -1.2],
       [ 1.2,  2.4, -1.2]], dtype=float32)

让我们逐个变换进行说明。

首先,你将 jax.grad() 应用于 td_loss,以获得一个计算单个(非批处理)输入上损失相对于参数的梯度的函数。

dtdloss_dtheta = jax.grad(td_loss)

dtdloss_dtheta(theta, s_tm1, r_t, s_t)
Array([ 1.2,  2.4, -1.2], dtype=float32)

此函数计算上述数组的一行。

然后,你使用 jax.vmap() 对此函数进行向量化。这会给所有输入和输出添加一个批处理维度。现在,给定一个输入批次,你将生成一个输出批次——批次中的每个输出都对应于输入批次中相应成员的梯度。

almost_perex_grads = jax.vmap(dtdloss_dtheta)

batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2,  2.4, -1.2],
       [ 1.2,  2.4, -1.2]], dtype=float32)

这还不是我们想要的,因为我们必须手动为该函数提供一批 theta,而我们实际上想使用单个 theta。我们通过向 jax.vmap() 添加 in_axes 来解决这个问题,将 theta 指定为 None,将其他参数指定为 0。这使得结果函数仅向其他参数添加一个额外的轴,使 theta 保持不批处理状态,正如我们所希望的那样。

inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))

inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2,  2.4, -1.2],
       [ 1.2,  2.4, -1.2]], dtype=float32)

这实现了我们的需求,但比原本可以达到的速度要慢。现在,将整个内容包装在 jax.jit() 中,以获得该函数的编译且高效的版本。

perex_grads = jax.jit(inefficient_perex_grads)

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2,  2.4, -1.2],
       [ 1.2,  2.4, -1.2]], dtype=float32)
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
3.01 ms ± 6.13 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.33 μs ± 8.76 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

使用 jax.gradjax.grad 进行 Hessian 向量积#

你可以使用高阶 jax.grad() 做的一件事是构建 Hessian 向量积函数。(稍后你将编写一个更高效的实现,结合前向和反向模式,但这一个将使用纯反向模式。)

Hessian 向量积函数在用于最小化平滑凸函数的 截断牛顿共轭梯度算法 中,或者用于研究神经网络训练目标的曲率时非常有用(例如 1, 2, 3, 4)。

对于具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\)(使得 Hessian 矩阵是对称的),点 \(x \in \mathbb{R}^n\) 处的 Hessian 记为 \(\partial^2 f(x)\)。那么 Hessian 向量积函数能够计算

\(\qquad v \mapsto \partial^2 f(x) \cdot v\)

对于任何 \(v \in \mathbb{R}^n\)

诀窍在于不要实例化完整的 Hessian 矩阵:如果 \(n\) 很大,在神经网络的上下文中可能是数百万或数十亿,那么存储它可能是不可能的。

幸运的是,jax.grad() 已经为我们提供了一种编写高效 Hessian 向量积函数的方法。你只需要利用恒等式:

\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),

其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(f\)\(x\) 处的梯度与向量 \(v\) 点积。请注意,你只是在对向量参数的标量值函数进行微分,这正是你知道 jax.grad() 高效的地方。

在 JAX 代码中,你可以直接写成:

def hvp(f, x, v):
    return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)

这个例子表明,你可以自由地使用词法闭包,而 JAX 永远不会因此感到困惑。

你将在下方几个单元格中检查此实现,一旦你学会了如何计算稠密 Hessian 矩阵。你还将编写一个使用前向模式和反向模式的更好版本。

使用 jax.jacfwdjax.jacrev 计算雅可比矩阵和 Hessian 矩阵#

你可以使用 jax.jacfwd()jax.jacrev() 函数计算完整的雅可比矩阵。

from jax import jacfwd, jacrev

# Define a sigmoid function.
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])

# Initialize random model coefficients
key = jax.random.key(0)
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05069415  0.1091874   0.07506633]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447992]
 [ 0.00574409 -0.0193281   0.01078958]]
jacrev result, with shape (4, 3)
[[ 0.05069415  0.10918741  0.07506634]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447995]
 [ 0.00574409 -0.0193281   0.01078958]]

这两个函数计算相同的值(在机器数值精度范围内),但在实现上有所不同:jax.jacfwd() 使用前向模式自动微分,这对于“高”雅可比矩阵(输出多于输入)更高效,而 jax.jacrev() 使用反向模式,这对于“宽”雅可比矩阵(输入多于输出)更高效。对于近乎方形的矩阵,jax.jacfwd() 可能比 jax.jacrev() 略有优势。

你也可以在容器类型中使用 jax.jacfwd()jax.jacrev()

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jax.jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)
Jacobian from W to logits is
[[ 0.05069415  0.10918741  0.07506634]
 [ 0.14170025 -0.17390487  0.02415345]
 [ 0.12579198  0.01451446 -0.31447995]
 [ 0.00574409 -0.0193281   0.01078958]]
Jacobian from b to logits is
[0.09748876 0.16102302 0.24190766 0.00776229]

有关前向和反向模式的更多详细信息,以及如何尽可能高效地实现 jax.jacfwd()jax.jacrev(),请继续阅读!

使用其中两个函数的组合,我们就可以计算稠密 Hessian 矩阵。

def hessian(f):
    return jax.jacfwd(jax.jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02058932  0.04434624  0.03048803]
  [ 0.04434623  0.09551499  0.06566654]
  [ 0.03048803  0.06566655  0.04514575]]

 [[-0.0743913   0.09129842 -0.01268033]
  [ 0.09129842 -0.11204806  0.01556223]
  [-0.01268034  0.01556223 -0.00216142]]

 [[ 0.01176856  0.00135791 -0.02942139]
  [ 0.00135791  0.00015668 -0.00339478]
  [-0.0294214  -0.00339478  0.07355348]]

 [[-0.00418412  0.014079   -0.00785936]
  [ 0.014079   -0.04737393  0.02644569]
  [-0.00785936  0.02644569 -0.01476286]]]

这种形状是有道理的:如果你从函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,你期望得到以下形状:

  • \(f(x) \in \mathbb{R}^m\),即 \(f\)\(x\) 处的值,

  • \(\partial f(x) \in \mathbb{R}^{m \times n}\),即 \(x\) 处的雅可比矩阵,

  • \(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),即 \(x\) 处的 Hessian 矩阵,

依此类推。

要实现 hessian,你可以使用 jacfwd(jacrev(f))jacrev(jacfwd(f)) 或这两者的任何其他组合。但前向-反向(forward-over-reverse)通常最有效。这是因为在内部雅可比矩阵计算中,我们通常是对宽雅可比函数(例如损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\))进行微分,而在外部雅可比矩阵计算中,我们是对具有方形雅可比矩阵的函数进行微分(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),这正是前向模式胜出的地方。