高级自动微分#
在本教程中,您将学习 JAX 中自动微分 (autodiff) 的复杂应用,并更深入地理解 JAX 中的导数计算是如何既简单又强大的。
请务必查看 自动微分 教程,回顾 JAX 自动微分的基础知识,如果您还没有看过的话。
设置#
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0)
计算梯度 (第 2 部分)#
高阶导数#
JAX 的自动微分可以轻松计算高阶导数,因为计算导数的函数本身是可微的。因此,高阶导数就像堆叠变换一样简单。
单变量情况已在 自动微分 教程中介绍,该教程示例展示了如何使用 jax.grad()
计算 \(f(x) = x^3 + 2x^2 - 3x + 1\) 的导数。
在多变量情况下,高阶导数更加复杂。一个函数的二阶导数由其 Hessian 矩阵 表示,定义如下:
一个多元实值函数 \(f: \mathbb R^n\to\mathbb R\) 的 Hessian 可以被识别为其梯度的 Jacobian。
JAX 提供了两个用于计算函数 Jacobian 的变换:jax.jacfwd()
和 jax.jacrev()
,分别对应前向模式和反向模式自动微分。它们得到相同的结果,但一种在不同情况下可能比另一种更有效 — 请参阅 关于自动微分的视频。
def hessian(f):
return jax.jacfwd(jax.grad(f))
让我们仔细检查一下点积 \(f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}\) 的正确性。
如果 \(i=j\),则 \(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2\)。否则,\(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0\)。
def f(x):
return jnp.dot(x, x)
hessian(f)(jnp.array([1., 2., 3.]))
Array([[2., 0., 0.],
[0., 2., 0.],
[0., 0., 2.]], dtype=float32)
高阶优化#
一些元学习技术,例如模型无关元学习 (Model-Agnostic Meta-Learning, MAML),需要对梯度更新进行微分。在其他框架中这可能相当麻烦,但在 JAX 中则容易得多。
def meta_loss_fn(params, data):
"""Computes the loss after one step of SGD."""
grads = jax.grad(loss_fn)(params, data)
return loss_fn(params - lr * grads, data)
meta_grads = jax.grad(meta_loss_fn)(params, data)
停止梯度#
自动微分能够自动计算函数相对于其输入的梯度。然而,有时您可能需要额外的控制:例如,您可能希望避免通过计算图的某个子集反向传播梯度。
以 TD(0) (时间差分) 强化学习更新为例。这用于学习估计环境中状态的值,基于与环境交互的经验。假设一个状态 \(s_{t-1}\) 的值估计 \(v_{\theta}(s_{t-1}\)) 由线性函数参数化。
# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
考虑一个从状态 \(s_{t-1}\) 到状态 \(s_t\) 的转换,在此期间您观察到了奖励 \(r_t\)。
# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])
网络参数的 TD(0) 更新为:
此更新不是任何损失函数的梯度。
然而,它可以被写成伪损失函数
的梯度,如果忽略目标 \(r_t + v_{\theta}(s_t)\) 对参数 \(\theta\) 的依赖性。
如何在 JAX 中实现这一点?如果您朴素地编写伪损失函数,您将得到:
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return -0.5 * ((target - v_tm1) ** 2)
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
Array([-1.2, 1.2, -1.2], dtype=float32)
但是 td_update
将不会计算 TD(0) 更新,因为梯度计算将包含 target
对 \(\theta\) 的依赖性。
您可以使用 jax.lax.stop_gradient()
来强制 JAX 忽略 target
对 \(\theta\) 的依赖性。
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
Array([ 1.2, 2.4, -1.2], dtype=float32)
这将把 target
视为不依赖于参数 \(\theta\),并计算参数的正确更新。
现在,让我们使用原始 TD(0) 更新表达式来计算 \(\Delta \theta\),以交叉检查我们的工作。您可能希望尝试使用 jax.grad()
和您迄今为止的知识来实现这一点。这是我们的解决方案:
s_grad = jax.grad(value_fn)(theta, s_tm1)
delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad
delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`
Array([ 1.2, 2.4, -1.2], dtype=float32)
jax.lax.stop_gradient
也可能在其他场景下有用,例如,如果您希望某个损失的梯度仅影响神经网络参数的子集(因为,例如,其他参数是使用不同的损失进行训练的)。
使用 stop_gradient
的直通估计器#
直通估计器是一种为否则不可微的函数定义“梯度”的技巧。给定一个不可微函数 \(f : \mathbb{R}^n \to \mathbb{R}^n\),该函数是我们要找到其梯度的较大函数的一部分,我们在反向传播时简单地假装 \(f\) 是恒等函数。这可以使用 jax.lax.stop_gradient
巧妙地实现。
def f(x):
return jnp.round(x) # non-differentiable
def straight_through_f(x):
# Create an exactly-zero expression with Sterbenz lemma that has
# an exactly-one gradient.
zero = x - jax.lax.stop_gradient(x)
return zero + jax.lax.stop_gradient(f(x))
print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))
print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
f(x): 3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0
每样本梯度#
虽然大多数机器学习系统为了计算效率和/或方差减少而计算批量数据的梯度和更新,但有时需要访问批次中每个特定样本的梯度/更新。
例如,这对于基于梯度幅度优先排序数据,或逐样本应用裁剪/归一化是必需的。
在许多框架 (PyTorch, TF, Theano) 中,计算每样本梯度通常并不简单,因为库会直接累积批次梯度。朴素的解决方法,例如为每个样本计算单独的损失然后聚合得到的梯度,通常效率非常低下。
在 JAX 中,您可以轻松高效地定义计算每样本梯度的代码。
只需将 jax.jit()
、jax.vmap()
和 jax.grad()
变换组合在一起。
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
让我们一次一个变换地进行。
首先,将 jax.grad()
应用于 td_loss
,得到一个计算损失相对于单个 (未批处理) 输入参数的梯度的函数。
dtdloss_dtheta = jax.grad(td_loss)
dtdloss_dtheta(theta, s_tm1, r_t, s_t)
Array([ 1.2, 2.4, -1.2], dtype=float32)
此函数计算上述数组的一行。
然后,使用 jax.vmap()
向量化此函数。这将为所有输入和输出添加一个批次维度。现在,给定一批输入,您会得到一批输出 — 批次中的每个输出对应于输入批次中相应成员的梯度。
almost_perex_grads = jax.vmap(dtdloss_dtheta)
batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
这并非完全是我们想要的,因为我们必须手动为该函数提供一批 theta
,而我们实际上想使用单个 theta
。通过向 jax.vmap()
添加 in_axes
来解决这个问题,将 theta 指定为 None
,其他参数指定为 0
。这使得生成的函数仅为其他参数添加一个额外的轴,而 theta
保持未批处理,正如我们所期望的那样。
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
这实现了我们想要的功能,但比它本应慢。现在,您将整个内容包装在一个 jax.jit()
中,以获得相同函数的编译、高效版本。
perex_grads = jax.jit(inefficient_perex_grads)
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
4.2 ms ± 7.73 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.55 μs ± 10.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
使用 jax.grad
-of-jax.grad
计算 Hessian-向量积#
使用高阶 jax.grad()
可以构建 Hessian-向量积函数。(稍后您将编写一个更有效的实现,它混合了前向模式和反向模式,但这个实现将仅使用反向模式。)
Hessian-向量积函数在 截断牛顿共轭梯度算法 中用于最小化平滑凸函数,或用于研究神经网络训练目标的曲率(例如,1、2、3、4)。
对于一个具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\) (因此 Hessian 矩阵是对称的),在点 \(x \in \mathbb{R}^n\) 处的 Hessian 表示为 \(\partial^2 f(x)\)。那么 Hessian-向量积函数可以计算
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
对于任何 \(v \in \mathbb{R}^n\)。
诀窍在于不实例化完整的 Hessian 矩阵:如果 \(n\) 很大,在神经网络的上下文中可能达到数百万甚至数十亿,那么存储它可能是不可能的。
幸运的是,jax.grad()
已经提供了一种编写高效 Hessian-向量积函数的方法。您只需要使用这个恒等式:
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(f\) 在 \(x\) 处的梯度与向量 \(v\) 点积。请注意,您只对向量值参数的标量值函数进行微分,这正是我们知道 jax.grad()
高效的地方。
在 JAX 代码中,您可以直接这样写:
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这个例子表明您可以自由使用词法闭包,JAX 永远不会被干扰或混淆。
您将在几行代码后检查此实现,届时您将学会如何计算稠密的 Hessian 矩阵。您还将编写一个更好的版本,它同时使用前向模式和反向模式。
使用 jax.jacfwd
和 jax.jacrev
计算 Jacobian 和 Hessian#
您可以使用 jax.jacfwd()
和 jax.jacrev()
函数计算完整的 Jacobian 矩阵。
from jax import jacfwd, jacrev
# Define a sigmoid function.
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05069415 0.1091874 0.07506633]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447992]
[ 0.00574409 -0.0193281 0.01078958]]
jacrev result, with shape (4, 3)
[[ 0.05069415 0.10918741 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
这两个函数计算的值相同 (在机器精度范围内),但在实现上有所不同:jax.jacfwd()
使用前向模式自动微分,对于“高瘦”的 Jacobian 矩阵(输出比输入多)更有效,而 jax.jacrev()
使用反向模式,对于“宽胖”的 Jacobian 矩阵(输入比输出多)更有效。对于接近正方形的矩阵,jax.jacfwd()
可能比 jax.jacrev()
更有优势。
您还可以将 jax.jacfwd()
和 jax.jacrev()
与容器类型一起使用。
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05069415 0.10918741 0.07506634]
[ 0.14170025 -0.17390487 0.02415345]
[ 0.12579198 0.01451446 -0.31447995]
[ 0.00574409 -0.0193281 0.01078958]]
Jacobian from b to logits is
[0.09748876 0.16102302 0.24190766 0.00776229]
有关前向模式和反向模式的更多详细信息,以及如何尽可能高效地实现 jax.jacfwd()
和 jax.jacrev()
,请继续阅读!
组合使用这两种函数之一可以让我们计算稠密的 Hessian 矩阵。
def hessian(f):
return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02058932 0.04434624 0.03048803]
[ 0.04434623 0.09551499 0.06566654]
[ 0.03048803 0.06566655 0.04514575]]
[[-0.0743913 0.09129842 -0.01268033]
[ 0.09129842 -0.11204806 0.01556223]
[-0.01268034 0.01556223 -0.00216142]]
[[ 0.01176856 0.00135791 -0.02942139]
[ 0.00135791 0.00015668 -0.00339478]
[-0.0294214 -0.00339478 0.07355348]]
[[-0.00418412 0.014079 -0.00785936]
[ 0.014079 -0.04737393 0.02644569]
[-0.00785936 0.02644569 -0.01476286]]]
这个形状是有意义的:如果您从一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\),您期望得到以下形状:
\(f(x) \in \mathbb{R}^m\),\(f\) 在 \(x\) 处的值,
\(\partial f(x) \in \mathbb{R}^{m \times n}\),在 \(x\) 处的 Jacobian 矩阵,
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),在 \(x\) 处的 Hessian,
依此类推。
要实现 hessian
,您可以使用 jacfwd(jacrev(f))
或 jacrev(jacfwd(f))
或这些函数的任何其他组合。但前向套反向通常是最有效的。这是因为在内部 Jacobian 计算中,我们通常在微分一个宽 Jacobian 的函数(可能像一个损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\)),而在外部 Jacobian 计算中,我们是在微分一个具有方形 Jacobian 的函数(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),而这正是前向模式占优的领域。
它是如何制作的:两个基础的自动微分函数#
Jacobian-向量积 (JVP,也称前向模式自动微分)#
JAX 包含了前向模式和反向模式自动微分的高效通用实现。熟悉的 jax.grad()
函数是基于反向模式构建的,但要解释两种模式之间的区别以及何时每种模式有用,您需要一些数学背景。
JVP 数学原理#
在数学上,给定一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\),在输入点 \(x \in \mathbb{R}^n\) 处计算的 \(f\) 的 Jacobian,记为 \(\partial f(x)\),通常被认为是 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的一个矩阵。
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
但您也可以将 \(\partial f(x)\) 看作一个线性映射,它将 \(f\) 在点 \(x\) 处的切空间(就是另一个 \(\mathbb{R}^n\) 的副本)映射到 \(f\) 的对偶空间在点 \(f(x)\) 处的切空间(\(\mathbb{R}^m\) 的副本)。
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
这个映射被称为 \(f\) 在 \(x\) 处的推前映射 (pushforward map)。Jacobian 矩阵就是这个线性映射在标准基上的矩阵。
如果您不确定一个特定的输入点 \(x\),那么您可以将函数 \(\partial f\) 视为先接受一个输入点,然后返回该输入点处的 Jacobian 线性映射。
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
特别地,您可以解包,这样给定输入点 \(x \in \mathbb{R}^n\) 和切向量 \(v \in \mathbb{R}^n\),您将得到一个输出切向量 \(\mathbb{R}^m\)。我们将从 \((x, v)\) 对到输出切向量的映射称为Jacobian-向量积,并将其表示为
\(\qquad (x, v) \mapsto \partial f(x) v\)
JVP JAX 代码实现#
回到 Python 代码,JAX 的 jax.jvp()
函数模拟了这个变换。给定一个计算 \(f\) 的 Python 函数,JAX 的 jax.jvp()
是一种获取一个 Python 函数来计算 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的方法。
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
用 类似 Haskell 的类型签名 来表示,您可以写成:
jvp :: (a -> b) -> a -> T a -> (b, T b)
其中 T a
用于表示 a
的切空间的类型。
换句话说,jvp
接收的参数是一个类型为 a -> b
的函数,一个类型为 a
的值,和一个类型为 T a
的切向量值。它返回一个对,包含一个类型为 b
的值和一个类型为 T b
的输出切向量。
jvp
转换后的函数评估方式与原始函数类似,但配对每个原始值 a
,它会传递类型为 T a
的切向量。对于原始函数会应用到的每一种原始数值操作,jvp
转换后的函数会执行该原始操作的“JVP 规则”,该规则同时在原始值上评估该原始操作,并将其线性化版本应用于向量。
这种评估策略对计算复杂度有一些直接影响。由于我们边评估边计算 JVP,所以我们不需要存储任何东西供以后使用,因此内存成本与计算深度无关。此外,jvp
转换后的函数的 FLOP 成本大约是仅评估该函数的成本的 3 倍(例如,评估 sin(x)
需要一个工作单元;线性化,如 cos(x)
,需要一个工作单元;以及将线性化函数应用于向量,如 cos_x * v
,需要一个工作单元)。换句话说,对于一个固定的原始点 \(x\),我们可以以与评估 \(f\) 差不多的边际成本来评估 \(v \mapsto \partial f(x) \cdot v\)。
这种内存复杂度听起来很有吸引力!那么为什么我们在机器学习中很少看到前向模式呢?
要回答这个问题,首先想想如何使用 JVP 来构建一个完整的 Jacobian 矩阵。如果我们对一个单热切向量应用 JVP,它会揭示 Jacobian 矩阵的一列,对应于我们输入的非零项。因此,我们可以一次构建一个完整的 Jacobian 列,并且获取每一列的成本与一次函数评估的成本差不多。这对于输出比输入多的 Jacobian (即“高瘦”的 Jacobian) 来说是有效的,但对于输入比输出多的 Jacobian (即“宽胖”的 Jacobian) 则效率低下。
如果您在机器学习中进行基于梯度的优化,您可能希望最小化一个从 \(\mathbb{R}^n\) 中的参数到标量损失值 \(\mathbb{R}\) 的损失函数。这意味着该函数的 Jacobian 是一个非常宽的矩阵: \(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们通常将其识别为梯度向量 \(\nabla f(x) \in \mathbb{R}^n\)。一次构建一个列,每次调用都需要与评估原始函数相似数量的 FLOPs,这肯定显得效率低下!特别地,对于训练神经网络,其中 \(f\) 是一个训练损失函数,而 \(n\) 可能达到数百万甚至数十亿,这种方法根本无法扩展。
要改进这类函数,您只需要使用反向模式。
向量-Jacobian 积 (VJP,也称反向模式自动微分)#
前向模式为我们提供了计算 Jacobian-向量积的函数,然后我们可以用它来一次一个列地构建 Jacobian 矩阵。反向模式是一种获取计算向量-Jacobian 积(等效于 Jacobian-转置-向量积)的函数的方法,我们可以用它来一次一个行地构建 Jacobian 矩阵。
VJP 数学原理#
我们再次考虑一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。从 JVP 的符号开始,VJP 的符号很简单:
\(\qquad (x, v) \mapsto v \partial f(x)\),
其中 \(v\) 是 \(f\) 在 \(x\) 处的余切空间(同构于另一个 \(\mathbb{R}^m\) 的副本)的一个元素。在严谨的表述中,我们应该将 \(v\) 视为一个线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),当我们写 \(v \partial f(x)\) 时,我们指的是函数组合 \(v \circ \partial f(x)\),其中类型是匹配的,因为 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常见情况下,我们可以将 \(v\) 与 \(\mathbb{R}^m\) 中的向量识别,并几乎不加区分地使用它们,就像我们有时会在“列向量”和“行向量”之间切换而没有太多评论一样。
有了这个识别,我们可以将 VJP 的线性部分视为 JVP 的线性部分的转置(或伴随共轭)。
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
对于给定的点 \(x\),我们可以写出其签名:
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
对应的余切空间上的映射通常被称为 \(f\) 在 \(x\) 处的拉回映射 (pullback)。对我们来说,关键在于它从看起来像 \(f\) 输出的东西映射到一个看起来像 \(f\) 输入的东西,就像我们从一个转置的线性函数中期望的那样。
VJP JAX 代码实现#
从数学切换回 Python,JAX 函数 vjp
可以接收一个计算 \(f\) 的 Python 函数,并返回一个计算 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函数。
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
用 类似 Haskell 的类型签名 来表示,我们可以写成:
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
其中我们使用 CT a
来表示 a
的余切空间的类型。用通俗的话来说,vjp
接收的参数是一个类型为 a -> b
的函数和一个类型为 a
的点,并返回一个对,包含一个类型为 b
的值和一个类型为 CT b -> CT a
的线性映射。
这很棒,因为它允许我们一次一个行地构建 Jacobian 矩阵,并且计算 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOPs 成本仅为计算 \(f\) 的成本的约三倍。特别地,如果我们想要一个函数 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我们可以在一次调用中完成。这就是 jax.grad()
对基于梯度的优化高效的原因,即使是对于像神经网络训练损失函数这样具有数百万或数十亿参数的目标。
不过,这也有成本:虽然 FLOPs 友好,但内存会随着计算深度的增加而扩展。此外,其实现通常比前向模式复杂,尽管 JAX 有一些绝招(这将在未来的笔记本中讲述!)。
有关反向模式工作原理的更多信息,请查看 2017 年深度学习暑期学校的本次教程视频。
使用 VJP 计算向量值梯度#
如果您对计算向量值梯度(例如 tf.gradients
)感兴趣。
def vgrad(f, x):
y, vjp_fn = vjp(f, x)
return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
[6. 6.]]
使用前向和反向模式计算 Hessian-向量积#
在前面的部分中,您仅使用反向模式实现了一个 Hessian-向量积函数(假设二阶导数连续)。
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这是有效的,但您可以通过结合使用前向模式和反向模式做得更好,并节省一些内存。
在数学上,给定一个要微分的函数 \(f : \mathbb{R}^n \to \mathbb{R}\),一个要线性化的点 \(x \in \mathbb{R}^n\),以及一个向量 \(v \in \mathbb{R}^n\),我们想要的 Hessian-向量积函数是:
\((x, v) \mapsto \partial^2 f(x) v\)
考虑辅助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),它定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。您只需要它的 JVP,因为它将为您提供:
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
我们几乎可以直接将其翻译成代码:
# forward-over-reverse
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
甚至更好的是,由于您不必直接调用 jnp.dot()
,这个 hvp
函数可以与任何形状的数组以及任意容器类型(如存储为嵌套列表/字典/元组的向量)一起工作,并且甚至不依赖于 jax.numpy
。
这是一个如何使用它的示例:
def f(X):
return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
您还可以考虑使用反向模式套前向模式来编写此代码:
# Reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jvp(f, primals, tangents)[1]
return grad(g)(primals)
这并不完全好,因为前向模式的开销比反向模式小,并且由于外部微分算子必须微分比内部算子更大的计算量,因此将前向模式放在外面效果最好。
# Reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
x, = primals
v, = tangents
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
3.12 ms ± 71.1 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
The slowest run took 6.91 times longer than the fastest. This could mean that an intermediate result is being cached.
14.2 ms ± 13.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
The slowest run took 4.40 times longer than the fastest. This could mean that an intermediate result is being cached.
17 ms ± 12.7 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
34.6 ms ± 1.06 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
组合 VJP、JVP 和 jax.vmap
#
Jacobian-矩阵 和 矩阵-Jacobian 积#
既然您已经有了 jax.jvp()
和 jax.vjp()
变换,它们为您提供了逐个向量地推前或拉回的函数,您可以使用 JAX 的 jax.vmap()
变换 来一次性推拉整个基。特别是,您可以使用它来编写快速的矩阵-Jacobian 和 Jacobian-矩阵积。
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
return jnp.vstack([vjp_fun(mi) for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
outs, = vmap(vjp_fun)(M)
return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
46.3 ms ± 60.3 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
3.16 ms ± 52.4 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_1332/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
# jvp immediately returns the primal and tangent values as a tuple,
# so we'll compute and select the tangents in a list comprehension
return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
_jvp = lambda s: jvp(f, (W,), (s,))[1]
return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
88.3 ms ± 10.1 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
1.4 ms ± 43.8 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
jax.jacfwd
和 jax.jacrev
的实现#
既然我们已经看到了快速的 Jacobian-矩阵和矩阵-Jacobian 积,那么要猜出如何编写 jax.jacfwd()
和 jax.jacrev()
并不难。我们只需使用相同的技术一次性推前或拉回整个标准基(同构于单位矩阵)。
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
def jacfun(x):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
有趣的是,Autograd 库无法做到这一点。Autograd 中反向模式 jacobian
的实现必须通过外部循环 map
来逐个向量地拉回。一次通过计算推前一个向量比使用 jax.vmap()
将所有内容一起批处理效率低得多。
Autograd 另一个无法做到的事情是 jax.jit()
。有趣的是,无论您在要微分的函数中使用多少 Python 动态性,我们总是可以使用 jax.jit()
来处理计算的线性部分。例如:
def f(x):
try:
if x < 3:
return 2 * x ** 3
else:
raise ValueError
except ValueError:
return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
复数与微分#
JAX 在复数和微分方面表现出色。为了同时支持全纯和非全纯微分,最好从 JVP 和 VJP 的角度来考虑。
考虑一个复数到复数的函数 \(f: \mathbb{C} \to \mathbb{C}\) 并将其识别为对应的函数 \(g: \mathbb{R}^2 \to \mathbb{R}^2\):
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def g(x, y):
return (u(x, y), v(x, y))
也就是说,我们将 \(f(z) = u(x, y) + v(x, y) i\) 分解,其中 \(z = x + y i\),并将 \(\mathbb{C}\) 识别为 \(\mathbb{R}^2\) 以得到 \(g\)。
由于 \(g\) 只涉及实数输入和输出,我们已经知道如何为它编写 Jacobian-向量积,例如,给定一个切向量 \((c, d) \in \mathbb{R}^2\),即:
\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
为了得到应用于切向量 \(c + di \in \mathbb{C}\) 的原始函数 \(f\) 的 JVP,我们只需使用相同的定义并将结果识别为另一个复数:
\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数的 JVP 的定义!请注意,\(f\) 是否全纯并不重要:JVP 是明确的。
这是一个检查:
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# tangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_dot = c + d * 1j
# check jvp
_, ans = jvp(fun, (z,), (z_dot,))
expected = (grad(u, 0)(x, y) * c +
grad(u, 1)(x, y) * d +
grad(v, 0)(x, y) * c * 1j+
grad(v, 1)(x, y) * d * 1j)
print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True
VJP 呢?我们做类似的事情:对于一个余切向量 \(c + di \in \mathbb{C}\),我们将 \(f\) 的 VJP 定义为:
\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).
负号是怎么回事?它们只是用来处理复共轭,以及我们正在处理余向量的事实。
这是一个 VJP 规则的检查:
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# cotangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_bar = jnp.array(c + d * 1j) # for dtype control
# check vjp
_, fun_vjp = vjp(fun, z)
ans, = fun_vjp(z_bar)
expected = (grad(u, 0)(x, y) * c +
grad(v, 0)(x, y) * (-d) +
grad(u, 1)(x, y) * c * (-1j) +
grad(v, 1)(x, y) * (-d) * (-1j))
assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
像 jax.grad()
、jax.jacfwd()
和 jax.jacrev()
这样的便利包装器呢?
对于 \(\mathbb{R} \to \mathbb{R}\) 函数,回想一下我们将 grad(f)(x)
定义为 vjp(f, x)[1](1.0)
,这是有效的,因为将 VJP 应用于 1.0
值会揭示梯度(即 Jacobian,或导数)。我们对 \(\mathbb{C} \to \mathbb{R}\) 函数也可以做同样的事情:我们仍然可以使用 1.0
作为余切向量,我们得到一个复数结果来总结整个 Jacobian。
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return x**2 + y**2
z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
对于通用的 \(\mathbb{C} \to \mathbb{C}\) 函数,Jacobian 有 4 个实值自由度(如上面的 2x2 Jacobian 矩阵所示),所以我们无法在一个复数中表示所有这些自由度。但对于全纯函数,我们可以!全纯函数本质上是一个 \(\mathbb{C} \to \mathbb{C}\) 函数,它具有一个特殊属性,即其导数可以表示为一个单一的复数。(柯西-黎曼方程确保上面的 2x2 Jacobian 具有复平面中尺度旋转矩阵的形式,即通过乘法作用于一个复数。)我们可以通过一次调用 vjp
并使用 1.0
作为余切向量来揭示这个单一的复数。
由于这只对全纯函数有效,为了使用这个技巧,我们需要向 JAX 承诺我们的函数是全纯的;否则,当 jax.grad()
用于复数输出函数时,JAX 将会抛出错误。
def f(z):
return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)
所有 holomorphic=True
的承诺只是在输出为复值时禁用错误。当函数不是全纯的时候,我们仍然可以写 holomorphic=True
,但我们得到的结果将不代表完整的 Jacobian。相反,它将是函数 Jacobian 的结果,其中我们只是丢弃了输出的虚部。
def f(z):
return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)
这里有一些关于 jax.grad()
如何工作的有用推论:
我们可以对全纯 \(\mathbb{C} \to \mathbb{C}\) 函数使用
jax.grad()
。我们可以使用
jax.grad()
来优化 \(f : \mathbb{C} \to \mathbb{R}\) 函数,例如复参数x
的实值损失函数,通过沿grad(f)(x)
的共轭方向进行步进。如果我们有一个 \(\mathbb{R} \to \mathbb{R}\) 函数,它只是碰巧在内部使用了某些复数值操作(其中一些必须是非全纯的,例如卷积中使用的 FFT),那么
jax.grad()
仍然有效,我们得到的结果与仅使用实数值的实现相同。
无论如何,JVP 和 VJP 始终是明确的。而且,如果我们想计算一个非全纯 \(\mathbb{C} \to \mathbb{C}\) 函数的完整 Jacobian 矩阵,我们可以使用 JVP 或 VJP 来完成!
您应该期望复数在 JAX 的任何地方都能正常工作。下面是一个通过复数矩阵的 Cholesky 分解进行微分的例子。
A = jnp.array([[5., 2.+3j, 5j],
[2.-3j, 7., 1.+7j],
[-5j, 1.-7j, 12.]])
def f(X):
L = jnp.linalg.cholesky(X)
return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A)
Array([[-0.7534186 +0.j , -3.0509028 -10.940544j ,
5.9896846 +3.5423026j],
[-3.0509028 +10.940544j , -8.904491 +0.j ,
-5.1351523 -6.559373j ],
[ 5.9896846 -3.5423026j, -5.1351523 +6.559373j ,
0.01320427 +0.j ]], dtype=complex64)
为 JAX 可转换的 Python 函数定义自定义导数规则#
在 JAX 中定义微分规则有两种方式:
使用
jax.custom_jvp()
和jax.custom_vjp()
为已经是 JAX 可转换的 Python 函数定义自定义微分规则;以及定义新的
core.Primitive
实例及其所有变换规则,例如调用其他系统的函数,如求解器、模拟器或通用数值计算系统。
本笔记本是关于 #1 的。要阅读关于 #2 的内容,请参阅关于添加原语的笔记本。
简而言之:使用 jax.custom_jvp()
定义自定义 JVP#
from jax import custom_jvp
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
return primal_out, tangent_out
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
# Equivalent alternative using the `defjvps` convenience wrapper
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
简而言之:使用 jax.custom_vjp
定义自定义 VJP#
from jax import custom_vjp
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by `f_bwd`.
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res # Gets residuals computed in `f_fwd`
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
示例问题#
为了了解 jax.custom_jvp()
和 jax.custom_vjp()
旨在解决哪些问题,让我们回顾几个例子。对 jax.custom_jvp()
和 jax.custom_vjp()
API 的更全面介绍在下一节中。
示例:数值稳定性#
jax.custom_jvp()
的一个应用是提高微分的数值稳定性。
假设我们要编写一个名为 log1pexp
的函数,它计算 \(x \mapsto \log ( 1 + e^x )\)。我们可以使用 jax.numpy
来编写它:
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp(3.)
Array(3.0485873, dtype=float32, weak_type=True)
由于它是用 jax.numpy
编写的,所以它是 JAX 可转换的:
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
但这里潜藏着一个数值稳定性问题:
print(grad(log1pexp)(100.))
nan
这似乎不对!毕竟,\(x \mapsto \log (1 + e^x)\) 的导数是 \(x \mapsto \frac{e^x}{1 + e^x}\),所以对于大的 \(x\) 值,我们期望其值约为 1。
我们可以通过查看梯度计算的 jaxpr 来获得更深入的了解。
from jax import make_jaxpr
make_jaxpr(grad(log1pexp))(100.)
{ lambda ; a:f32[]. let
b:f32[] = exp a
c:f32[] = add 1.0:f32[] b
_:f32[] = log c
d:f32[] = div 1.0:f32[] c
e:f32[] = mul d b
in (e,) }
逐步评估 jaxpr,注意最后一行将涉及将浮点数 math 计算结果四舍五入到 0 和 \(\infty\) 的值,这永远不是个好主意。也就是说,对于大的 x
,我们实际上是在计算 lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)
,这实际上变成了 0. * jnp.inf
。
与其生成如此大和小的数,期望浮点数无法始终提供的精确抵消,我们宁愿将导数函数表示为一个更数值稳定的程序。特别是,我们可以编写一个程序,更紧密地评估数学上相等的表达式 \(1 - \frac{1}{1 + e^x}\),而没有任何抵消。
这个问题很有趣,因为尽管我们对 log1pexp
的定义已经可以 JAX 微分(并使用 jax.jit()
、jax.vmap()
等进行变换),但我们对将标准自动微分规则应用于构成 log1pexp
的原语并组合结果不满意。相反,我们希望指定整个函数 log1pexp
应该如何被微分,作为一个整体,从而更好地安排指数的计算。
这是自定义 JAX 可转换 Python 函数导数规则的一个应用:指定一个复合函数应该如何微分,同时仍然为其他变换(如 jax.jit()
、jax.vmap()
等)使用其原始 Python 定义。
这是一个使用 jax.custom_jvp()
的解决方案:
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = log1pexp(x)
ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
return ans, ans_dot
print(grad(log1pexp)(100.))
1.0
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
这是一个 defjvps
便利包装器,用于表达相同的内容:
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
1.0
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
示例:强制执行微分约定#
一个相关的应用是强制执行微分约定,也许是在边界处。
考虑函数 \(f : \mathbb{R}_+ \to \mathbb{R}_+\),其中 \(f(x) = \frac{x}{1 + \sqrt{x}}\),其中 \(\mathbb{R}_+ = [0, \infty)\)。我们可以像这样实现 \(f\) 的程序:
def f(x):
return x / (1 + jnp.sqrt(x))
作为一个在 \(\mathbb{R}\)(整个实数轴)上的数学函数,\(f\) 在零点不可微(因为定义导数的极限从左侧不存在)。相应地,自动微分会产生一个 nan
值:
print(grad(f)(0.))
nan
但在数学上,如果我们把 \(f\) 看作一个在 \(\mathbb{R}_+\) 上的函数,那么它在 0 点是可微的 [Rudin 的《数学分析原理》定义 5.1,或 Tao 的《分析 I》第 3 版定义 10.1.1 和示例 10.1.6]。或者,我们可以说,按照约定,我们想要考虑从右侧的方向导数。因此,Python 函数 grad(f)
在 0.0
处返回一个有意义的值,即 1.0
。默认情况下,JAX 的微分机制假设所有函数都定义在 \(\mathbb{R}\) 上,因此在这里不会产生 1.0
。
我们可以使用自定义 JVP 规则!特别是,我们可以根据函数 \(x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}\) 在 \(\mathbb{R}_+\) 上的导数来定义 JVP 规则:
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
return ans, ans_dot
print(grad(f)(0.))
1.0
这是便利包装器版本:
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
print(grad(f)(0.))
1.0
示例:梯度裁剪#
虽然在某些情况下我们希望表达一个数学上的微分计算,但在其他情况下,我们甚至可能想要暂时偏离数学,来调整自动微分执行的计算。一个典型的例子是反向模式梯度裁剪。
对于梯度裁剪,我们可以使用 jnp.clip()
结合一个 jax.custom_vjp()
仅反向模式规则:
from functools import partial
@custom_vjp
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi) # save bounds as residuals
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
import matplotlib.pyplot as plt
t = jnp.linspace(0, 10, 1000)
plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t))
[<matplotlib.lines.Line2D at 0x79117940a780>]

def clip_sin(x):
x = clip_gradient(-0.75, 0.75, x)
return jnp.sin(x)
plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))
[<matplotlib.lines.Line2D at 0x79117949e390>]

示例:Python 调试#
另一个由开发流程而非数值驱动的应用是在反向模式自动微分的反向传播中设置一个 pdb
调试器断点。
在试图找出 nan
运行时错误的根源,或者只是仔细检查正在传播的余切(梯度)值时,将调试器插入到对应于原始计算特定点的反向传播点可能很有用。您可以使用 jax.custom_vjp()
来做到这一点。
我们将在下一节中推迟一个例子。
示例:迭代实现的隐函数微分#
这个例子深入到数学的细节!
jax.custom_vjp()
的另一个应用是反向模式微分对于 JAX 可转换(通过 jax.jit()
、jax.vmap()
等)但由于某种原因无法高效 JAX 微分的函数,可能是因为它们涉及 jax.lax.while_loop()
。(不可能生成一个 XLA HLO 程序来高效计算 XLA HLO While 循环的反向模式导数,因为这需要一个内存无界限的程序,而这无法在 XLA HLO 中表达,至少没有通过 infeed/outfeed 的“副作用”交互)。
例如,考虑这个 fixed_point
例程,它通过在 while_loop
中迭代应用一个函数来计算不动点:
from jax.lax import while_loop
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return jnp.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
这是一个数值求解方程 \(x = f(a, x)\) 的迭代过程,通过迭代 \(x_{t+1} = f(a, x_t)\) 直到 \(x_{t+1}\) 非常接近 \(x_t\)。结果 \(x^*\) 取决于参数 \(a\),所以我们可以认为存在一个函数 \(a \mapsto x^*(a)\),它是由方程 \(x = f(a, x)\) 隐式定义的。
我们可以使用 fixed_point
来运行迭代过程直到收敛,例如,只执行加法、乘法和除法来运行牛顿法计算平方根:
def newton_sqrt(a):
update = lambda a, x: 0.5 * (x + a / x)
return fixed_point(update, a, a)
print(newton_sqrt(2.))
1.4142135
我们也可以将该函数 jax.vmap()
或 jax.jit()
。
print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
[1. 1.4142135 1.7320509 2. ]
由于 while_loop
,我们无法应用反向模式自动微分,但实际上我们也不想这样做:与其微分 fixed_point
的实现及其所有迭代,不如利用数学结构来做一些内存效率更高(并且在这种情况下,FLOPs 效率也更高!)的事情。我们可以改用隐函数定理 [Bertsekas 的非线性规划,第 2 版 A.25 命题],它保证(在某些条件下)我们即将使用的数学对象的存在。本质上,我们线性化解并迭代求解这些线性方程来计算我们想要的导数。
再次考虑方程 \(x = f(a, x)\) 和函数 \(x^*\)。我们想计算向量-Jacobian 积,例如 \(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)\)。
至少在我们要微分的点 \(a_0\) 的一个开邻域内,假设方程 \(x^*(a) = f(a, x^*(a))\) 对所有 \(a\) 都成立。由于两边是 \(a\) 的函数,因此相等,它们的导数也必须相等,所以让我们对两边进行微分:
\(\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)\).
令 \(A = \partial_1 f(a_0, x^*(a_0))\) 和 \(B = \partial_0 f(a_0, x^*(a_0))\),我们可以更简单地写出我们想要的量:
\(\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)\),
或者,通过重新排列:
\(\qquad \partial x^*(a_0) = (I - A)^{-1} B\).
这意味着我们可以计算向量-Jacobian 积,例如:
\(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B\),
其中 \(w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}\),或者等价地 \(w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A\),或者等价地 \(w^\mathsf{T}\) 是映射 \(u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A\) 的不动点。最后一个特征描述为我们提供了一种使用对 fixed_point
的调用来编写 fixed_point
的 VJP 的方法!此外,展开 \(A\) 和 \(B\) 后,您可以得出结论,只需要在 \((a_0, x^*(a_0))\) 处评估 \(f\) 的 VJP。
结果如下:
@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return jnp.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
def fixed_point_fwd(f, a, x_init):
x_star = fixed_point(f, a, x_init)
return x_star, (a, x_star)
def fixed_point_rev(f, res, x_star_bar):
a, x_star = res
_, vjp_a = vjp(lambda a: f(a, x_star), a)
a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
(a, x_star, x_star_bar),
x_star_bar))
return a_bar, jnp.zeros_like(x_star)
def rev_iter(f, packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: f(a, x), x_star)
return x_star_bar + vjp_x(u)[0]
fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
print(newton_sqrt(2.))
1.4142135
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))
0.35355338
-0.088388346
我们可以通过对 jnp.sqrt()
进行微分来检查我们的答案,它使用了完全不同的实现:
print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.))
0.35355338
-0.08838835
这种方法的一个限制是参数 f
不能闭包任何涉及微分的值。也就是说,您可能会注意到我们将参数 a
保留在 fixed_point
的参数列表中。对于这种情况,请考虑使用低级原语 lax.custom_root
,它允许对具有自定义根查找函数的闭包变量进行微分。
jax.custom_jvp
和 jax.custom_vjp
API 的基本用法#
使用 jax.custom_jvp
定义前向模式(以及间接的反向模式)规则#
这是一个使用 jax.custom_jvp()
的经典基本示例,其中注释使用了 类似 Haskell 的类型签名:
# f :: a -> b
@custom_jvp
def f(x):
return jnp.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), jnp.cos(x) * t
f.defjvp(f_jvp)
<function __main__.f_jvp(primals, tangents)>
print(f(3.))
y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)
0.14112
0.14112
-0.9899925
换句话说,我们从一个接受类型为 a
的输入并产生类型为 b
的输出的原始函数 f
开始。我们为其关联一个 JVP 规则函数 f_jvp
,该函数接受一对代表类型为 a
的原始输入的输入和对应类型为 T a
的切线输入,并产生一对代表类型为 b
的原始输出和类型为 T b
的切线输出。切线输出应是切线输入的线性函数。
您也可以使用 f.defjvp
作为装饰器,如下所示:
@custom_jvp
def f(x):
...
@f.defjvp
def f_jvp(primals, tangents):
...
尽管我们只定义了 JVP 规则而没有定义 VJP 规则,但我们仍然可以对 f
使用前向模式和反向模式微分。JAX 会自动转置我们自定义 JVP 规则上切线值的线性计算,从而像我们手动编写规则一样高效地计算 VJP。
print(grad(f)(3.))
print(grad(grad(f))(3.))
-0.9899925
-0.14112
为了使自动转置生效,JVP 规则的输出切线必须是输入切线的线性函数。否则将引发转置错误。
多个参数的工作方式如下:
@custom_jvp
def f(x, y):
return x ** 2 * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
return primal_out, tangent_out
print(grad(f)(2., 3.))
12.0
便利包装器 defjvps
允许我们单独为每个参数定义 JVP,并将结果分别计算然后求和。
@custom_jvp
def f(x):
return jnp.sin(x)
f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
print(grad(f)(3.))
-0.9899925
这是一个具有多个参数的 defjvps
示例:
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
4.0
作为一种简写,使用 defjvps
,您可以传递 None
值来表示特定参数的 JVP 为零。
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
None)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
0.0
使用关键字参数调用 jax.custom_jvp()
函数,或编写带有默认参数的 jax.custom_jvp()
函数定义,都是允许的,只要它们可以根据标准库 inspect.signature
机制检索到的函数签名进行无歧义的映射到位置参数。
当您不执行微分时,函数 f
的调用方式与未被 jax.custom_jvp()
装饰时相同。
@custom_jvp
def f(x):
print('called f!') # a harmless side-effect
return jnp.sin(x)
@f.defjvp
def f_jvp(primals, tangents):
print('called f_jvp!') # a harmless side-effect
x, = primals
t, = tangents
return f(x), jnp.cos(x) * t
print(f(3.))
called f!
0.14112
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.))
called f!
[0. 0.84147096 0.9092974 ]
called f!
0.14112
自定义 JVP 规则在微分(前向或反向)期间被调用。
y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)
called f_jvp!
called f!
-0.9899925
print(grad(f)(3.))
called f_jvp!
called f!
-0.9899925
请注意,f_jvp
调用 f
来计算原始输出。在更高阶微分的上下文中,每次应用微分变换时,只要该规则调用原始 f
来计算原始输出,该规则就会被使用。(这代表了一种基本的权衡,即我们无法在我们的规则中利用 f
的求值中间值并且 还能使规则在所有更高阶微分的顺序中都适用。)
grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
Array(-0.14112, dtype=float32, weak_type=True)
您可以使用 Python 控制流和 jax.custom_jvp()
。
@custom_jvp
def f(x):
if x > 0:
return jnp.sin(x)
else:
return jnp.cos(x)
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
if x > 0:
return ans, 2 * x_dot
else:
return ans, 3 * x_dot
print(grad(f)(1.))
print(grad(f)(-1.))
2.0
3.0
使用 jax.custom_vjp
来定义仅限反向模式的自定义规则#
虽然 jax.custom_jvp()
足以控制前向模式和(通过 JAX 的自动转置)反向模式微分行为,但在某些情况下,我们可能希望直接控制 VJP 规则,例如在上文介绍的最后两个示例问题中。我们可以使用 jax.custom_vjp()
来实现这一点。
from jax import custom_vjp
# f :: a -> b
@custom_vjp
def f(x):
return jnp.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
return f(x), jnp.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd)
print(f(3.))
print(grad(f)(3.))
0.14112
-0.9899925
换句话说,我们再次从一个接受类型为 a
的输入并产生类型为 b
的输出的原始函数 f
开始。我们为其关联两个函数,f_fwd
和 f_bwd
,分别描述了如何执行反向模式自动微分的前向和后向传播。
函数 f_fwd
描述了前向传播,不仅包括原始计算,还包括为了在后向传播中使用而需要保存的值。其输入签名与原始函数 f
相同,即它接受类型为 a
的原始输入。但作为输出,它产生一个对,其中第一个元素是原始输出 b
,第二个元素是为后向传播使用而存储的任何“残差”数据(类型为 c
)。(第二个输出类似于 PyTorch 的 save_for_backward 机制。)
函数 f_bwd
描述了后向传播。它接受两个输入,第一个是 f_fwd
产生的类型为 c
的残差数据,第二个是对应于原始函数输出的类型为 CT b
的输出共切线。它产生一个类型为 CT a
的输出,表示对应于原始函数输入的共切线。特别地,f_bwd
的输出必须是一个序列(例如,一个元组),其长度等于原始函数的参数数量。
所以多个参数的工作方式如下:
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
使用关键字参数调用 jax.custom_vjp()
函数,或编写带有默认参数的 jax.custom_vjp()
函数定义,都是允许的,只要它们可以根据标准库 inspect.signature
机制检索到的函数签名进行无歧义的映射到位置参数。
与 jax.custom_jvp()
类似,由 f_fwd
和 f_bwd
组成的自定义 VJP 规则在不应用微分时不会被调用。如果函数被求值,或被 jax.jit()
、jax.vmap()
或其他非微分变换转换,则只调用 f
。
@custom_vjp
def f(x):
print("called f!")
return jnp.sin(x)
def f_fwd(x):
print("called f_fwd!")
return f(x), jnp.cos(x)
def f_bwd(cos_x, y_bar):
print("called f_bwd!")
return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd)
print(f(3.))
called f!
0.14112
print(grad(f)(3.))
called f_fwd!
called f!
called f_bwd!
-0.9899925
y, f_vjp = vjp(f, 3.)
print(y)
called f_fwd!
called f!
0.14112
print(f_vjp(1.))
called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),)
无法在前向模式自动微分中使用 jax.custom_vjp()
函数,否则将引发错误。
from jax import jvp
try:
jvp(f, (3.,), (1.,))
except TypeError as e:
print('ERROR! {}'.format(e))
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.
如果您想同时使用前向模式和反向模式,请改用 jax.custom_jvp()
。
我们可以将 jax.custom_vjp()
与 pdb
结合使用,在后向传播中插入调试器跟踪。
import pdb
@custom_vjp
def debug(x):
return x # acts like identity
def debug_fwd(x):
return x, x
def debug_bwd(x, g):
import pdb; pdb.set_trace()
return g
debug.defvjp(debug_fwd, debug_bwd)
def foo(x):
y = x ** 2
y = debug(y) # insert pdb in corresponding backward pass step
return jnp.sin(y)
jax.grad(foo)(3.)
> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q
更多功能和细节#
处理 list
/ tuple
/ dict
容器(和其他 pytrees)#
您应该期望像列表、元组、命名元组和字典这样的标准 Python 容器能够正常工作,以及这些容器的嵌套版本。一般来说,任何 pytrees 都是允许的,只要它们的结构根据类型约束是一致的。
这是一个使用 jax.custom_jvp()
的牵强附会的示例。
from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])
@custom_jvp
def f(pt):
x, y = pt.x, pt.y
return {'a': x ** 2,
'b': (jnp.sin(x), jnp.cos(y))}
@f.defjvp
def f_jvp(primals, tangents):
pt, = primals
pt_dot, = tangents
ans = f(pt)
ans_dot = {'a': 2 * pt.x * pt_dot.x,
'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
return ans, ans_dot
def fun(pt):
dct = f(pt)
return dct['a'] + dct['b'][0]
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True))
以及一个使用 jax.custom_vjp()
的类似的牵强附会的示例。
@custom_vjp
def f(pt):
x, y = pt.x, pt.y
return {'a': x ** 2,
'b': (jnp.sin(x), jnp.cos(y))}
def f_fwd(pt):
return f(pt), pt
def f_bwd(pt, g):
a_bar, (b0_bar, b1_bar) = g['a'], g['b']
x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
y_bar = -jnp.sin(pt.y) * b1_bar
return (Point(x_bar, y_bar),)
f.defvjp(f_fwd, f_bwd)
def fun(pt):
dct = f(pt)
return dct['a'] + dct['b'][0]
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True))
处理不可微分的参数#
一些用例,例如最后一个示例问题,需要将函数值参数等不可微分的参数传递给具有自定义微分规则的函数,并且还需要将这些参数也传递给规则。在 fixed_point
的情况下,函数参数 f
就是这样一个不可微分的参数。类似的情况也出现在 jax.experimental.odeint
中。
nondiff_argnums
配合 jax.custom_jvp
#
使用 jax.custom_jvp()
的可选 nondiff_argnums
参数来指示此类参数。这是一个使用 jax.custom_jvp()
的示例。
from functools import partial
@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
return f(x)
@app.defjvp
def app_jvp(f, primals, tangents):
x, = primals
x_dot, = tangents
return f(x), 2. * x_dot
print(app(lambda x: x ** 3, 3.))
27.0
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0
请注意这里的陷阱:无论这些参数在参数列表中出现在哪里,它们都位于相应 JVP 规则的签名**开头**。这是另一个例子:
@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
return f(g((x)))
@app2.defjvp
def app2_jvp(f, g, primals, tangents):
x, = primals
x_dot, = tangents
return f(g(x)), 3. * x_dot
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0
nondiff_argnums
配合 jax.custom_vjp
#
jax.custom_vjp()
存在类似选项,同样,约定是不可微分参数作为第一个参数传递给 _bwd
规则,无论它们出现在原始函数的签名中的哪个位置。 _fwd
规则的签名保持不变——它与原始函数的签名相同。这是一个示例:
@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
return f(x)
def app_fwd(f, x):
return f(x), x
def app_bwd(f, x, g):
return (5 * g,)
app.defvjp(app_fwd, app_bwd)
print(app(lambda x: x ** 2, 4.))
16.0
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0
请参考上面的 fixed_point
以获取另一个使用示例。
**您不需要为数组值参数使用** nondiff_argnums
,例如具有整数 dtype 的参数。相反,nondiff_argnums
仅应用于不对应 JAX 类型(本质上不对应数组类型)的参数值,例如 Python 可调用对象或字符串。如果 JAX 检测到 nondiff_argnums
指定的参数包含 JAX Tracer,则会引发错误。上面的 clip_gradient
函数就是一个不为整数 dtype 数组参数使用 nondiff_argnums
的好例子。
下一步#
还有许多其他自动微分技巧和功能。本教程中未涵盖但值得进一步探讨的主题包括:
高斯-牛顿向量积,一次线性化
自定义 VJP 和 JVP
固定点的高效导数
使用随机 Hessian-向量积估计 Hessian 的迹
仅使用反向模式自动微分的前向模式自动微分
对自定义数据类型进行微分
检查点(用于高效反向模式的二项式检查点,而非模型快照)
使用 Jacobian 预累加优化 VJP