高级自动微分#
在本教程中,您将学习 JAX 中自动微分 (autodiff) 的复杂应用,并更好地理解如何在 JAX 中求导既简单又强大。
如果您还没有这样做,请务必查看 自动微分 教程,回顾 JAX 自动微分的基础知识。
设置#
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0)
求梯度(第 2 部分)#
高阶导数#
JAX 的自动微分使得计算高阶导数变得容易,因为计算导数的函数本身是可微的。因此,高阶导数就像堆叠转换一样简单。
单变量的情况在自动微分教程中已经介绍过,该示例展示了如何使用 jax.grad()
计算 \(f(x) = x^3 + 2x^2 - 3x + 1\) 的导数。
在多变量情况下,高阶导数更为复杂。函数的二阶导数由其 黑塞矩阵 表示,定义如下
多个变量的实值函数 \(f: \mathbb R^n\to\mathbb R\) 的黑塞矩阵可以识别为其梯度的雅可比矩阵。
JAX 提供了两种用于计算函数雅可比矩阵的转换,jax.jacfwd()
和 jax.jacrev()
,分别对应于前向和反向模式自动微分。它们给出相同的答案,但在不同情况下,一种可能比另一种更有效——请参阅关于自动微分的视频。
def hessian(f):
return jax.jacfwd(jax.grad(f))
让我们在点积 \(f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}\) 上仔细检查这是正确的。
如果 \(i=j\),\(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2\)。否则,\(\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0\)。
def f(x):
return jnp.dot(x, x)
hessian(f)(jnp.array([1., 2., 3.]))
Array([[2., 0., 0.],
[0., 2., 0.],
[0., 0., 2.]], dtype=float32)
高阶优化#
一些元学习技术,例如模型无关元学习(MAML),需要通过梯度更新进行微分。在其他框架中,这可能相当麻烦,但在 JAX 中要容易得多
def meta_loss_fn(params, data):
"""Computes the loss after one step of SGD."""
grads = jax.grad(loss_fn)(params, data)
return loss_fn(params - lr * grads, data)
meta_grads = jax.grad(meta_loss_fn)(params, data)
停止梯度#
自动微分可以自动计算函数关于其输入的梯度。但是,有时您可能需要一些额外的控制:例如,您可能希望避免通过计算图的某些子集反向传播梯度。
例如,考虑 TD(0)(时间差分)强化学习更新。这用于从与环境交互的经验中学习估计环境中状态的价值。假设状态 \(s_{t-1}\) 中的价值估计 \(v_{\theta}(s_{t-1}\)) 由线性函数参数化。
# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
考虑从状态 \(s_{t-1}\) 到状态 \(s_t\) 的转换,在此期间您观察到奖励 \(r_t\)
# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])
网络参数的 TD(0) 更新为
此更新不是任何损失函数的梯度。
但是,它可以写成伪损失函数的梯度
如果忽略目标 \(r_t + v_{\theta}(s_t)\) 对参数 \(\theta\) 的依赖性。
如何在 JAX 中实现这一点?如果您天真地编写伪损失,您会得到
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return -0.5 * ((target - v_tm1) ** 2)
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
Array([-1.2, 1.2, -1.2], dtype=float32)
但是 td_update
将不计算 TD(0) 更新,因为梯度计算将包括 target
对 \(\theta\) 的依赖性。
您可以使用 jax.lax.stop_gradient()
强制 JAX 忽略目标对 \(\theta\) 的依赖性
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
Array([ 1.2, 2.4, -1.2], dtype=float32)
这将把 target
视为它不依赖于参数 \(\theta\),并计算对参数的正确更新。
现在,让我们也使用原始的 TD(0) 更新表达式计算 \(\Delta \theta\),以交叉检查我们的工作。您可能希望尝试使用 jax.grad()
和您目前所掌握的知识自行实现此操作。这是我们的解决方案
s_grad = jax.grad(value_fn)(theta, s_tm1)
delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad
delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`
Array([ 1.2, 2.4, -1.2], dtype=float32)
jax.lax.stop_gradient
也可能在其他设置中很有用,例如,如果您希望某些损失的梯度仅影响神经网络参数的子集(例如,因为其他参数使用不同的损失进行训练)。
使用 stop_gradient
的直通估计器#
直通估计器是一种定义原本不可微的函数的“梯度”的技巧。给定一个不可微的函数 \(f : \mathbb{R}^n \to \mathbb{R}^n\),该函数用作我们希望找到梯度的较大函数的一部分,我们在反向传递期间简单地假装 \(f\) 是恒等函数。这可以使用 jax.lax.stop_gradient
巧妙地实现
def f(x):
return jnp.round(x) # non-differentiable
def straight_through_f(x):
# Create an exactly-zero expression with Sterbenz lemma that has
# an exactly-one gradient.
zero = x - jax.lax.stop_gradient(x)
return zero + jax.lax.stop_gradient(f(x))
print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))
print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
f(x): 3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0
每个示例的梯度#
虽然大多数 ML 系统出于计算效率和/或方差减少的原因,从批量数据中计算梯度和更新,但有时有必要访问与批次中每个特定样本关联的梯度/更新。
例如,这对于根据梯度大小来优先处理数据,或者逐个样本地应用裁剪/归一化是必需的。
在许多框架(PyTorch、TF、Theano)中,计算每个示例的梯度通常不是一件容易的事,因为库会直接累积批次上的梯度。天真的解决方法,例如计算每个示例的单独损失,然后聚合生成的梯度,通常效率非常低。
在 JAX 中,您可以以简单但有效的方式定义代码来计算每个样本的梯度。
只需将 jax.jit()
、jax.vmap()
和 jax.grad()
转换结合在一起
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
让我们一次完成一个转换。
首先,您将 jax.grad()
应用于 td_loss
,以获得一个函数,该函数计算单个(未批处理)输入上损失相对于参数的梯度
dtdloss_dtheta = jax.grad(td_loss)
dtdloss_dtheta(theta, s_tm1, r_t, s_t)
Array([ 1.2, 2.4, -1.2], dtype=float32)
此函数计算上述数组的一行。
然后,您使用 jax.vmap()
对此函数进行矢量化。这会向所有输入和输出添加批次维度。现在,给定一批输入,您会生成一批输出——批次中的每个输出都对应于输入批次中相应成员的梯度。
almost_perex_grads = jax.vmap(dtdloss_dtheta)
batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
这并不是我们想要的,因为我们必须手动向此函数提供一批 theta
,而我们实际上想使用单个 theta
。我们通过将 in_axes
添加到 jax.vmap()
来解决此问题,将 theta 指定为 None
,将其他参数指定为 0
。这使得生成的函数仅向其他参数添加额外的轴,而使 theta
保持未批处理,正如我们想要的那样
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
这个函数可以实现我们想要的功能,但速度比它应有的速度慢。现在,您可以使用 jax.jit()
将整个函数包装起来,以获得该函数的编译优化版本。
perex_grads = jax.jit(inefficient_perex_grads)
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
Array([[ 1.2, 2.4, -1.2],
[ 1.2, 2.4, -1.2]], dtype=float32)
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
3.63 ms ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.83 μs ± 16.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
使用 jax.grad
的 jax.grad
计算 Hessian 向量积#
使用高阶 jax.vmap()
可以做的一件事是构建 Hessian 向量积函数。(稍后您将编写一个更高效的实现,它混合了前向模式和反向模式,但这里我们将使用纯反向模式。)
Hessian 向量积函数在 截断牛顿共轭梯度算法 中可用于最小化平滑凸函数,或用于研究神经网络训练目标的曲率(例如,1, 2, 3, 4)。
对于具有连续二阶导数的标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\)(因此 Hessian 矩阵是对称的),在点 \(x \in \mathbb{R}^n\) 处的 Hessian 记为 \(\partial^2 f(x)\)。Hessian 向量积函数可以计算
\(\qquad v \mapsto \partial^2 f(x) \cdot v\)
对于任意 \(v \in \mathbb{R}^n\)。
诀窍是不实例化完整的 Hessian 矩阵:如果 \(n\) 很大,可能在神经网络的背景下达到数百万或数十亿,那么这可能无法存储。
幸运的是,jax.vmap()
已经提供了一种编写高效 Hessian 向量积函数的方法。您只需要使用恒等式
\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\),
其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将 \(f\) 在 \(x\) 处的梯度与向量 \(v\) 进行点积运算。请注意,您始终只对向量值参数的标量值函数进行微分,这正是您知道 jax.vmap()
高效的原因。
在 JAX 代码中,您可以这样写
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这个例子表明您可以自由使用词法闭包,JAX 永远不会感到困惑或混淆。
一旦您学会如何计算稠密 Hessian 矩阵,您将在下面几个单元格中检查此实现。您还将编写一个使用前向模式和反向模式的更佳版本。
使用 jax.jacfwd
和 jax.jacrev
计算雅可比矩阵和 Hessian 矩阵#
您可以使用 jax.jacfwd()
和 jax.jacrev()
函数计算完整的雅可比矩阵
from jax import jacfwd, jacrev
# Define a sigmoid function.
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981758 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188288 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188289 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
这两个函数计算相同的值(直到机器数值精度),但在实现上有所不同:jax.jacfwd()
使用前向模式自动微分,这对于“高”雅可比矩阵(输出多于输入)更有效,而 jax.jacrev()
使用反向模式,这对于“宽”雅可比矩阵(输入多于输出)更有效。对于接近正方形的矩阵,jax.jacfwd()
可能比 jax.jacrev()
更具优势。
您还可以将 jax.jacfwd()
和 jax.jacrev()
与容器类型一起使用
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05981757 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188289 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]
有关前向模式和反向模式的更多详细信息,以及如何尽可能有效地实现 jax.jacfwd()
和 jax.jacrev()
,请继续阅读!
使用这两个函数的组合,我们可以计算稠密的 Hessian 矩阵
def hessian(f):
return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285465 0.04922541 0.03384247]
[ 0.04922541 0.10602397 0.07289147]
[ 0.03384247 0.07289147 0.05011288]]
[[-0.03195215 0.03921401 -0.00544639]
[ 0.03921401 -0.04812629 0.00668421]
[-0.00544639 0.00668421 -0.00092836]]
[[-0.01583708 -0.00182736 0.03959271]
[-0.00182736 -0.00021085 0.00456839]
[ 0.03959271 0.00456839 -0.09898177]]
[[-0.00103524 0.00348343 -0.00194457]
[ 0.00348343 -0.01172127 0.0065432 ]
[-0.00194457 0.0065432 -0.00365263]]]
这个形状是有意义的:如果您从函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在点 \(x \in \mathbb{R}^n\) 处,您应该得到以下形状
\(f(x) \in \mathbb{R}^m\),\(f\) 在 \(x\) 处的值,
\(\partial f(x) \in \mathbb{R}^{m \times n}\),在 \(x\) 处的雅可比矩阵,
\(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),在 \(x\) 处的 Hessian 矩阵,
等等。
要实现 hessian
,您可以使用 jacfwd(jacrev(f))
或 jacrev(jacfwd(f))
或这两个的任何其他组合。但前向模式叠加在反向模式之上通常是最有效的。这是因为在内部雅可比矩阵计算中,我们通常对一个宽雅可比矩阵函数求导(可能像损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\)),而在外部雅可比矩阵计算中,我们对一个具有方雅可比矩阵的函数求导(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),这正是前向模式获胜的地方。
它是如何制作的:两个基础自动微分函数#
雅可比矩阵-向量积(JVPs,又名,前向模式自动微分)#
JAX 包含前向模式和反向模式自动微分的高效通用实现。熟悉的 jax.vmap()
函数是基于反向模式构建的,但为了解释这两种模式之间的区别以及何时可以使用每种模式,您需要一些数学背景知识。
数学中的 JVP#
在数学上,给定一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\),\(f\) 在输入点 \(x \in \mathbb{R}^n\) 处计算的雅可比矩阵,记为 \(\partial f(x)\),通常被认为是 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的矩阵
\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).
但您也可以将 \(\partial f(x)\) 视为一个线性映射,它将 \(f\) 的定义域在点 \(x\) 处的切空间(它只是 \(\mathbb{R}^n\) 的另一个副本)映射到 \(f\) 的值域在点 \(f(x)\) 处的切空间(\(\mathbb{R}^m\) 的副本)
\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).
此映射称为 \(f\) 在 \(x\) 处的前推映射。雅可比矩阵只是此线性映射在标准基上的矩阵。
如果您不确定一个具体的输入点 \(x\),那么您可以将函数 \(\partial f\) 视为首先获取一个输入点,并返回该输入点处的雅可比线性映射
\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).
特别是,您可以将事物解耦,以便在给定输入点 \(x \in \mathbb{R}^n\) 和切向量 \(v \in \mathbb{R}^n\) 的情况下,您会得到 \(\mathbb{R}^m\) 中的输出切向量。我们将该映射(从 \((x, v)\) 对到输出切向量)称为雅可比矩阵-向量积,并将其写为
\(\qquad (x, v) \mapsto \partial f(x) v\)
JAX 代码中的 JVP#
回到 Python 代码中,JAX 的 jax.jvp()
函数对这个转换建模。给定一个求值 \(f\) 的 Python 函数,JAX 的 jax.jvp()
是一种获取求值 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的 Python 函数的方法。
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
用 类似 Haskell 的类型签名来说,您可以写
jvp :: (a -> b) -> a -> T a -> (b, T b)
其中 T a
用于表示 a
的切空间的类型。
换句话说,jvp
接受一个类型为 a -> b
的函数、一个类型为 a
的值和一个类型为 T a
的切向量值作为参数。它返回一个由类型为 b
的值和类型为 T b
的输出切向量组成的对。
jvp
转换后的函数在计算时与原始函数非常相似,但它会将类型为 a
的每个原始值与类型为 T a
的切向量值配对。对于原始函数会应用的每个基本数值运算,jvp
转换后的函数都会执行该基本运算的 “JVP 规则”,该规则既对原始值求值,又在这些原始值上应用该基本运算的 JVP。
这种求值策略对计算复杂度有一些直接的影响。由于我们在计算过程中求值 JVP,所以不需要为以后存储任何内容,因此内存成本与计算深度无关。此外,jvp
转换后的函数的 FLOP 成本大约是仅评估函数成本的 3 倍(评估原始函数的工作量为 1 个单位,例如 sin(x)
;线性化的工作量为 1 个单位,例如 cos(x)
;以及将线性化函数应用于向量的工作量为 1 个单位,例如 cos_x * v
)。换句话说,对于固定的原始点 \(x\),我们可以用与评估 \(f\) 相同的边际成本评估 \(v \mapsto \partial f(x) \cdot v\)。
内存复杂度听起来很有吸引力!那么为什么我们在机器学习中不经常看到前向模式呢?
要回答这个问题,首先考虑一下如何使用 JVP 来构建完整的雅可比矩阵。如果我们对一个 one-hot 切向量应用 JVP,它会揭示雅可比矩阵的一列,对应于我们输入的非零条目。因此,我们可以一次构建一个完整的雅可比矩阵的列,而获取每列的成本与一次函数评估的成本大致相同。对于具有“高”雅可比矩阵的函数,这将是高效的,但对于“宽”雅可比矩阵则效率低下。
如果在机器学习中进行基于梯度的优化,你可能希望最小化一个从 \(\mathbb{R}^n\) 中的参数到 \(\mathbb{R}\) 中的标量损失值的损失函数。这意味着此函数的雅可比矩阵是一个非常宽的矩阵:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们通常将其与梯度向量 \(\nabla f(x) \in \mathbb{R}^n\) 关联起来。一次构建该矩阵的一列,每次调用所花费的 FLOP 数与评估原始函数所花费的 FLOP 数相似,这肯定显得效率低下!特别是对于训练神经网络而言,其中 \(f\) 是训练损失函数,并且 \(n\) 可以达到数百万甚至数十亿,这种方法是无法扩展的。
为了更好地处理此类函数,你只需要使用反向模式。
向量-雅可比乘积(VJP,又名反向模式自动微分)#
前向模式为我们提供了一个用于评估雅可比-向量乘积的函数,我们可以使用它一次构建雅可比矩阵的一列,而反向模式是一种获取用于评估向量-雅可比乘积(等效于雅可比-转置-向量乘积)的函数的方法,我们可以使用它一次构建雅可比矩阵的一行。
数学中的 VJP#
让我们再次考虑一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。从我们对 JVP 的表示法开始,VJP 的表示法非常简单
\(\qquad (x, v) \mapsto v \partial f(x)\),
其中 \(v\) 是 \(f\) 在 \(x\) 处的余切空间的元素(与 \(\mathbb{R}^m\) 的另一个副本同构)。在严格意义上,我们应该将 \(v\) 视为线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),并且当我们写 \(v \partial f(x)\) 时,我们的意思是函数组合 \(v \circ \partial f(x)\),其中类型可以正确匹配,因为 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但是在常见情况下,我们可以将 \(v\) 与 \(\mathbb{R}^m\) 中的向量标识,并且几乎可以互换地使用两者,就像我们有时会在“列向量”和“行向量”之间翻转而无需过多注释一样。
通过这种标识,我们也可以将 VJP 的线性部分视为 JVP 的线性部分的转置(或伴随共轭)
\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).
对于给定点 \(x\),我们可以将签名写为
\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).
余切空间上的相应映射通常称为 \(f\) 在 \(x\) 处的拉回。对于我们的目的而言,关键在于它从看起来像 \(f\) 的输出的内容转到看起来像 \(f\) 的输入的内容,就像我们可能从转置线性函数中期望的那样。
JAX 代码中的 VJP#
从数学回到 Python,JAX 函数 vjp
可以接受一个用于评估 \(f\) 的 Python 函数,并返回一个用于评估 VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 Python 函数。
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
用类似 Haskell 的类型签名来说,我们可以写成
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
其中我们使用 CT a
来表示 a
的余切空间的类型。换句话说,vjp
接受一个类型为 a -> b
的函数和一个类型为 a
的点作为参数,并返回一个由类型为 b
的值和一个类型为 CT b -> CT a
的线性映射组成的对。
这很棒,因为它允许我们一次构建雅可比矩阵的一行,并且评估 \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的 FLOP 成本仅约为评估 \(f\) 的三倍。特别是,如果我们想要函数 \(f : \mathbb{R}^n \to \mathbb{R}\) 的梯度,我们只需调用一次即可完成。这就是为什么 jax.vmap()
对于基于梯度的优化来说是高效的,即使对于数百万甚至数十亿参数的神经网络训练损失函数等目标也是如此。
虽然 FLOP 很友好,但有一个成本,即内存会随着计算的深度而增加。此外,该实现传统上比前向模式的实现更复杂,尽管 JAX 有一些技巧(这是未来笔记本的故事!)。
有关反向模式如何工作的更多信息,请查看2017 年深度学习暑期学校的此教程视频。
使用 VJP 的向量值梯度#
如果你有兴趣获取向量值梯度(如 tf.gradients
)
def vgrad(f, x):
y, vjp_fn = vjp(f, x)
return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
[6. 6.]]
使用前向模式和反向模式的 Hessian-向量乘积#
在上一节中,你仅使用反向模式实现了一个 Hessian-向量乘积函数(假设连续二阶导数)
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这很高效,但是你可以通过将前向模式与反向模式结合使用来做得更好并节省一些内存。
在数学上,给定一个要微分的函数 \(f : \mathbb{R}^n \to \mathbb{R}\)、一个要在线性化该函数的点 \(x \in \mathbb{R}^n\) 和一个向量 \(v \in \mathbb{R}^n\),我们想要的 Hessian-向量乘积函数是
\((x, v) \mapsto \partial^2 f(x) v\)
考虑辅助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。你只需要它的 JVP,因为它会给我们
\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).
我们可以几乎直接将其转换为代码
# forward-over-reverse
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
更好的是,由于你无需直接调用 jnp.dot()
,因此此 hvp
函数可以处理任何形状的数组和任意容器类型(如存储为嵌套列表/字典/元组的向量),甚至不依赖于 jax.numpy
。
这是一个如何使用它的示例
def f(X):
return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
你可能考虑编写此代码的另一种方法是使用反向优先于前向
# Reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jvp(f, primals, tangents)[1]
return grad(g)(primals)
但这并不太好,因为前向模式的开销比反向模式小,并且由于此处的外部微分运算符必须微分比内部微分更大的计算,因此将前向模式保留在外部效果最佳
# Reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
x, = primals
v, = tangents
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
3.12 ms ± 142 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
The slowest run took 4.27 times longer than the fastest. This could mean that an intermediate result is being cached.
7.97 ms ± 5.86 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
The slowest run took 4.20 times longer than the fastest. This could mean that an intermediate result is being cached.
12.1 ms ± 8.79 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
36.6 ms ± 2.21 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
组合 VJP、JVP 和 jax.vmap
#
雅可比矩阵和矩阵-雅可比乘积#
既然您已经有了 jax.jvp()
和 jax.vjp()
转换,它们可以为您提供一次推送或拉回单个向量的函数,那么您可以使用 JAX 的 jax.vmap()
转换 一次性推送和拉回整个基。特别是,您可以使用它来编写快速的矩阵-雅可比矩阵和雅可比矩阵-矩阵乘积。
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
return jnp.vstack([vjp_fun(mi) for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
outs, = vmap(vjp_fun)(M)
return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
96.1 ms ± 312 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
3.12 ms ± 52.6 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_616/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
# jvp immediately returns the primal and tangent values as a tuple,
# so we'll compute and select the tangents in a list comprehension
return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
_jvp = lambda s: jvp(f, (W,), (s,))[1]
return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
124 ms ± 29.4 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
1.46 ms ± 37.7 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
jax.jacfwd
和 jax.jacrev
的实现#
现在我们已经了解了快速的雅可比矩阵-矩阵和矩阵-雅可比矩阵乘积,不难猜到如何编写 jax.jacfwd()
和 jax.jacrev()
。我们只是使用相同的技术一次性推送或拉回整个标准基(与单位矩阵同构)。
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
def jacfun(x):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
有趣的是,Autograd 库无法做到这一点。Autograd 中反向模式 jacobian
的实现必须使用外循环 map
一次拉回一个向量。一次性通过计算推送一个向量的效率远低于使用 jax.vmap()
将所有向量批量处理。
Autograd 无法做的另一件事是 jax.jit()
。有趣的是,无论您在要微分的函数中使用多少 Python 动态性,我们总可以在计算的线性部分使用 jax.jit()
。例如
def f(x):
try:
if x < 3:
return 2 * x ** 3
else:
raise ValueError
except ValueError:
return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
复数和微分#
JAX 在处理复数和微分方面非常出色。为了支持全纯和非全纯微分,有助于从 JVP 和 VJP 的角度思考。
考虑一个复数到复数的函数 \(f: \mathbb{C} \to \mathbb{C}\),并将其与相应的函数 \(g: \mathbb{R}^2 \to \mathbb{R}^2\) 对应起来,
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def g(x, y):
return (u(x, y), v(x, y))
也就是说,我们分解了 \(f(z) = u(x, y) + v(x, y) i\),其中 \(z = x + y i\),并将 \(\mathbb{C}\) 与 \(\mathbb{R}^2\) 对应起来以得到 \(g\)。
由于 \(g\) 仅涉及实数输入和输出,我们已经知道如何为其编写雅可比矩阵-向量乘积,例如给定切向量 \((c, d) \in \mathbb{R}^2\),即
\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
要获取应用于切向量 \(c + di \in \mathbb{C}\) 的原始函数 \(f\) 的 JVP,我们只需使用相同的定义并将结果标识为另一个复数,
\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\).
这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数的 JVP 的定义!请注意,\(f\) 是否全纯并不重要:JVP 是明确的。
这是一个检查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# tangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_dot = c + d * 1j
# check jvp
_, ans = jvp(fun, (z,), (z_dot,))
expected = (grad(u, 0)(x, y) * c +
grad(u, 1)(x, y) * d +
grad(v, 0)(x, y) * c * 1j+
grad(v, 1)(x, y) * d * 1j)
print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True
那么 VJP 呢?我们做一些非常相似的事情:对于余切向量 \(c + di \in \mathbb{C}\),我们将 \(f\) 的 VJP 定义为
\((c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix}\).
负号是怎么回事?它们只是为了处理复共轭,以及我们正在使用协向量这一事实。
这是对 VJP 规则的检查
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# cotangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_bar = jnp.array(c + d * 1j) # for dtype control
# check vjp
_, fun_vjp = vjp(fun, z)
ans, = fun_vjp(z_bar)
expected = (grad(u, 0)(x, y) * c +
grad(v, 0)(x, y) * (-d) +
grad(u, 1)(x, y) * c * (-1j) +
grad(v, 1)(x, y) * (-d) * (-1j))
assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
那么像 jax.grad()
, jax.jacfwd()
和 jax.jacrev()
这样的便捷包装器呢?
对于 \(\mathbb{R} \to \mathbb{R}\) 函数,回想一下我们定义 grad(f)(x)
为 vjp(f, x)[1](1.0)
,之所以有效是因为将 VJP 应用于 1.0
值会揭示梯度(即雅可比矩阵或导数)。我们可以对 \(\mathbb{C} \to \mathbb{R}\) 函数执行相同的操作:我们仍然可以使用 1.0
作为余切向量,并且我们只会得到一个复数结果,该结果总结了完整的雅可比矩阵。
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return x**2 + y**2
z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
对于一般的 \(\mathbb{C} \to \mathbb{C}\) 函数,雅可比矩阵具有 4 个实数值自由度(如上面的 2x2 雅可比矩阵中所示),因此我们不能希望在复数中表示所有这些自由度。但是对于全纯函数,我们可以!全纯函数恰好是一个 \(\mathbb{C} \to \mathbb{C}\) 函数,其导数可以用单个复数表示。(柯西-黎曼方程确保上述 2x2 雅可比矩阵在复平面中具有缩放和旋转矩阵的特殊形式,即单个复数在乘法下的作用。)我们可以使用单个对 vjp
的调用和一个 1.0
的协向量来揭示该复数。
因为这仅适用于全纯函数,所以要使用此技巧,我们需要向 JAX 保证我们的函数是全纯的。否则,当 jax.grad()
用于复数输出函数时,JAX 会引发错误。
def f(z):
return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)
所有 holomorphic=True
承诺所做的事情只是在输出为复数值时禁用错误。当函数不是全纯时,我们仍然可以编写 holomorphic=True
,但是我们得到的结果不会代表完整的雅可比矩阵。相反,它将是我们只丢弃输出的虚部的函数的雅可比矩阵。
def f(z):
return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)
关于 jax.grad()
在这里的工作方式,有一些有用的结果
我们可以对全纯的 \(\mathbb{C} \to \mathbb{C}\) 函数使用
jax.grad()
。我们可以使用
jax.grad()
来优化 \(f : \mathbb{C} \to \mathbb{R}\) 函数,例如复数参数x
的实值损失函数,通过沿grad(f)(x)
的共轭方向进行步进。如果我们有一个恰好在内部使用一些复数值运算的 \(\mathbb{R} \to \mathbb{R}\) 函数(其中一些必须是非全纯的,例如卷积中使用的 FFT),那么
jax.grad()
仍然有效,并且我们得到的结果与仅使用实数值的实现所给出的结果相同。
在任何情况下,JVP 和 VJP 始终是明确的。如果我们想计算非全纯的 \(\mathbb{C} \to \mathbb{C}\) 函数的完整雅可比矩阵,我们可以使用 JVP 或 VJP 来完成!
您应该期望复数在 JAX 中的任何地方都有效。这里是通过复矩阵的 Cholesky 分解进行微分
A = jnp.array([[5., 2.+3j, 5j],
[2.-3j, 7., 1.+7j],
[-5j, 1.-7j, 12.]])
def f(X):
L = jnp.linalg.cholesky(X)
return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A)
Array([[-0.7534186 +0.j , -3.0509028 -10.940544j ,
5.9896846 +3.5423026j],
[-3.0509028 +10.940544j , -8.904491 +0.j ,
-5.1351523 -6.559373j ],
[ 5.9896846 -3.5423026j, -5.1351523 +6.559373j ,
0.01320427 +0.j ]], dtype=complex64)
用于 JAX 可转换 Python 函数的自定义导数规则#
在 JAX 中定义微分规则有两种方法
使用
jax.custom_jvp()
和jax.custom_vjp()
为已经可进行 JAX 转换的 Python 函数定义自定义微分规则;以及定义新的
core.Primitive
实例及其所有转换规则,例如,从其他系统(如求解器、模拟器或通用数值计算系统)调用函数。
本笔记本是关于 #1 的。要阅读有关 #2 的内容,请参阅关于添加基元的笔记本。
TL;DR:使用 jax.custom_jvp()
自定义 JVP#
from jax import custom_jvp
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
return primal_out, tangent_out
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
# Equivalent alternative using the `defjvps` convenience wrapper
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
TL;DR:使用 jax.custom_vjp
自定义 VJP#
from jax import custom_vjp
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by `f_bwd`.
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res # Gets residuals computed in `f_fwd`
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
示例问题#
为了了解 jax.custom_jvp()
和 jax.custom_vjp()
旨在解决哪些问题,让我们回顾几个示例。下一节将更深入地介绍 jax.custom_jvp()
和 jax.custom_vjp()
API。
示例:数值稳定性#
jax.custom_jvp()
的一个应用是提高微分的数值稳定性。
假设我们想编写一个名为 log1pexp
的函数,它计算 \(x \mapsto \log ( 1 + e^x )\)。我们可以使用 jax.numpy
来编写它
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp(3.)
Array(3.0485873, dtype=float32, weak_type=True)
由于它是用 jax.numpy
编写的,因此它是可 JAX 转换的
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
但是这里潜藏着一个数值稳定性问题
print(grad(log1pexp)(100.))
nan
这看起来不太对!毕竟,\(x \mapsto \log (1 + e^x)\) 的导数是 \(x \mapsto \frac{e^x}{1 + e^x}\),因此对于较大的 \(x\) 值,我们期望该值大约为 1。
我们可以通过查看梯度计算的 jaxpr 来更深入地了解正在发生的事情
from jax import make_jaxpr
make_jaxpr(grad(log1pexp))(100.)
{ lambda ; a:f32[]. let
b:f32[] = exp a
c:f32[] = add 1.0 b
_:f32[] = log c
d:f32[] = div 1.0 c
e:f32[] = mul d b
in (e,) }
逐步执行 jaxpr 的评估过程,请注意,最后一行将涉及将浮点数学四舍五入为 0 和 \(\infty\) 的值相乘,这绝不是一个好主意。也就是说,我们实际上是在针对较大的 x
评估 lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)
,这实际上变成了 0. * jnp.inf
。
与其生成如此大和如此小的值,并希望浮点数无法总是提供的抵消,不如我们将导数函数表达为更数值稳定的程序。特别是,我们可以编写一个更接近于计算相等数学表达式 \(1 - \frac{1}{1 + e^x}\) 的程序,其中没有抵消。
这个问题很有趣,因为即使我们对 log1pexp
的定义已经可以进行 JAX 微分(并使用 jax.jit()
、jax.vmap()
等进行转换),但我们对将标准自动微分规则应用于组成 log1pexp
的原语并组成结果感到不满意。相反,我们希望指定应如何作为一个整体对整个函数 log1pexp
进行微分,从而更好地安排这些指数。
这是对已经可 JAX 转换的 Python 函数使用自定义导数规则的一种应用:指定应如何对复合函数进行微分,同时仍将其原始 Python 定义用于其他转换(如 jax.jit()
、jax.vmap()
等)。
这是使用 jax.custom_jvp()
的解决方案
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = log1pexp(x)
ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
return ans, ans_dot
print(grad(log1pexp)(100.))
1.0
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
这是一个 defjvps
便利包装器,用于表达相同的事情
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
1.0
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
示例:强制执行微分约定#
一个相关的应用是强制执行微分约定,可能在边界处。
考虑函数 \(f : \mathbb{R}_+ \to \mathbb{R}_+\),其中 \(f(x) = \frac{x}{1 + \sqrt{x}}\),我们取 \(\mathbb{R}_+ = [0, \infty)\)。我们可以像这样实现 \(f\) 的程序
def f(x):
return x / (1 + jnp.sqrt(x))
作为 \(\mathbb{R}\)(完整实数线)上的数学函数,\(f\) 在零处不可微(因为从左侧定义的导数极限不存在)。相应地,自动微分会生成 nan
值
print(grad(f)(0.))
nan
但从数学上讲,如果我们将 \(f\) 视为 \(\mathbb{R}_+\) 上的函数,则它在 0 处可微 [Rudin 的《数学分析原理》定义 5.1,或 Tao 的《分析 I》第 3 版定义 10.1.1 和示例 10.1.6]。或者,我们可能说按照约定,我们希望考虑从右侧的定向导数。因此,Python 函数 grad(f)
在 0.0
处返回一个有意义的值,即 1.0
。默认情况下,JAX 的微分机制假定所有函数都定义在 \(\mathbb{R}\) 上,因此此处不会生成 1.0
。
我们可以使用自定义 JVP 规则!特别是,我们可以根据 \(\mathbb{R}_+\) 上的导数函数 \(x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}\) 定义 JVP 规则,
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
return ans, ans_dot
print(grad(f)(0.))
1.0
这是便利包装器版本
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
print(grad(f)(0.))
1.0
示例:梯度裁剪#
虽然在某些情况下我们希望表达数学微分计算,但在其他情况下,我们甚至可能希望脱离数学来调整自动微分执行的计算。一个典型的例子是反向模式梯度裁剪。
对于梯度裁剪,我们可以将 jnp.clip()
与 jax.custom_vjp()
仅反向模式规则一起使用
from functools import partial
@custom_vjp
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi) # save bounds as residuals
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
import matplotlib.pyplot as plt
t = jnp.linspace(0, 10, 1000)
plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t))
[<matplotlib.lines.Line2D at 0x7fcc97fefcd0>]
def clip_sin(x):
x = clip_gradient(-0.75, 0.75, x)
return jnp.sin(x)
plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))
[<matplotlib.lines.Line2D at 0x7fcc95c9c1f0>]
示例:Python 调试#
另一个受开发工作流程而不是数值驱动的应用是在反向模式自动微分的反向传递中设置 pdb
调试器跟踪。
当尝试跟踪 nan
运行时错误的来源时,或者只是仔细检查正在传播的余切(梯度)值时,在反向传递中插入一个调试器,该调试器对应于原始计算中的特定点,这会很有用。您可以使用 jax.custom_vjp()
来做到这一点。
我们将把示例推迟到下一节。
示例:迭代实现的隐式函数微分#
这个例子在数学上有点深奥!
jax.custom_vjp()
的另一个应用是对可 JAX 转换(通过 jax.jit()
、jax.vmap()
等)但出于某些原因不能有效地进行 JAX 微分的函数进行反向模式微分,这可能是因为它们涉及 jax.lax.while_loop()
。(不可能生成一个 XLA HLO 程序来有效地计算 XLA HLO While 循环的反向模式导数,因为这将需要一个具有无界内存使用的程序,这在 XLA HLO 中无法表达,至少没有通过 infeed/outfeed 的“副作用”交互)。
例如,考虑这个 fixed_point
例程,它通过在 while_loop
中迭代应用函数来计算不动点
from jax.lax import while_loop
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return jnp.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
这是一种迭代过程,用于通过迭代 \(x_{t+1} = f(a, x_t)\) 直到 \(x_{t+1}\) 足够接近 \(x_t\),来数值求解方程 \(x = f(a, x)\) 中的 \(x\)。结果 \(x^*\) 取决于参数 \(a\),因此我们可以认为存在一个函数 \(a \mapsto x^*(a)\),该函数由方程 \(x = f(a, x)\) 隐式定义。
我们可以使用 fixed_point
来运行迭代过程以实现收敛,例如,运行牛顿法来计算平方根,同时仅执行加法、乘法和除法
def newton_sqrt(a):
update = lambda a, x: 0.5 * (x + a / x)
return fixed_point(update, a, a)
print(newton_sqrt(2.))
1.4142135
我们也可以 jax.vmap()
或 jax.jit()
该函数
print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
[1. 1.4142135 1.7320509 2. ]
由于 while_loop
,我们无法应用反向模式自动微分,但事实证明我们无论如何都不想这样做:与其通过 fixed_point
的实现及其所有迭代进行微分,不如利用数学结构来执行更节省内存(在这种情况下也更节省 FLOP!)的操作。相反,我们可以使用隐函数定理 [Bertsekas 的《非线性规划》第 2 版的命题 A.25],该定理(在某些条件下)保证了我们将要使用的数学对象的存在。本质上,我们将解线性化,并迭代求解这些线性方程以计算我们想要的导数。
再次考虑方程 \(x = f(a, x)\) 和函数 \(x^*\)。我们想评估向量-雅可比乘积,如 \(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)\)。
至少在我们想要微分的点 \(a_0\) 周围的开放邻域中,让我们假设方程 \(x^*(a) = f(a, x^*(a))\) 对于所有 \(a\) 都成立。由于等式的两边作为 \(a\) 的函数相等,因此它们的导数也必须相等,因此让我们对两边进行微分
\(\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)\).
设定 \(A = \partial_1 f(a_0, x^*(a_0))\) 和 \(B = \partial_0 f(a_0, x^*(a_0))\),我们可以将我们所求的量更简洁地写成
\(\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)\),
或者,通过重新整理,得到
\(\qquad \partial x^*(a_0) = (I - A)^{-1} B\).
这意味着我们可以评估向量-雅可比矩阵的乘积,例如
\(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B\),
其中 \(w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}\),或者等价地 \(w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A\),或者等价地 \(w^\mathsf{T}\) 是映射 \(u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A\) 的不动点。最后一个特征为我们提供了一种用 fixed_point
的调用来编写 fixed_point
的 VJP 的方法!此外,在展开 \(A\) 和 \(B\) 之后,您可以得出结论,您只需要评估 \(f\) 在 \((a_0, x^*(a_0))\) 处的 VJP 即可。
这是重点
@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return jnp.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
def fixed_point_fwd(f, a, x_init):
x_star = fixed_point(f, a, x_init)
return x_star, (a, x_star)
def fixed_point_rev(f, res, x_star_bar):
a, x_star = res
_, vjp_a = vjp(lambda a: f(a, x_star), a)
a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
(a, x_star, x_star_bar),
x_star_bar))
return a_bar, jnp.zeros_like(x_star)
def rev_iter(f, packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: f(a, x), x_star)
return x_star_bar + vjp_x(u)[0]
fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
print(newton_sqrt(2.))
1.4142135
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))
0.35355338
-0.088388346
我们可以通过微分 jnp.sqrt()
来检查我们的答案,它使用完全不同的实现
print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.))
0.35355338
-0.08838835
这种方法的一个限制是,参数 f
不能闭包任何涉及微分的值。也就是说,您可能会注意到我们在 fixed_point
的参数列表中显式保留了参数 a
。对于此用例,请考虑使用低级原语 lax.custom_root
,它允许在具有自定义求根函数的闭包变量中进行导数。
jax.custom_jvp
和 jax.custom_vjp
API 的基本用法#
使用 jax.custom_jvp
定义前向模式(以及间接地,反向模式)规则#
这是一个使用 jax.custom_jvp()
的规范基本示例,其中注释使用了 Haskell 类型的签名
# f :: a -> b
@custom_jvp
def f(x):
return jnp.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), jnp.cos(x) * t
f.defjvp(f_jvp)
<function __main__.f_jvp(primals, tangents)>
print(f(3.))
y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)
0.14112
0.14112
-0.9899925
换句话说,我们从一个原始函数 f
开始,该函数接受类型为 a
的输入并产生类型为 b
的输出。我们将其与 JVP 规则函数 f_jvp
相关联,该函数接受一对输入,分别表示类型为 a
的原始输入和类型为 T a
的相应切线输入,并产生一对输出,分别表示类型为 b
的原始输出和类型为 T b
的切线输出。切线输出应该是切线输入的线性函数。
您也可以使用 f.defjvp
作为装饰器,如下所示
@custom_jvp
def f(x):
...
@f.defjvp
def f_jvp(primals, tangents):
...
即使我们只定义了一个 JVP 规则而没有 VJP 规则,我们也可以在 f
上使用前向和反向模式微分。JAX 将自动转置来自我们自定义 JVP 规则的切线值上的线性计算,其计算 VJP 的效率与我们手动编写规则时一样
print(grad(f)(3.))
print(grad(grad(f))(3.))
-0.9899925
-0.14112
为了使自动转置工作,JVP 规则的输出切线必须是输入切线的线性函数。否则会引发转置错误。
多个参数的工作方式如下
@custom_jvp
def f(x, y):
return x ** 2 * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
return primal_out, tangent_out
print(grad(f)(2., 3.))
12.0
defjvps
便利包装器允许我们为每个参数单独定义 JVP,并且结果将单独计算然后求和
@custom_jvp
def f(x):
return jnp.sin(x)
f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
print(grad(f)(3.))
-0.9899925
这是一个带有多个参数的 defjvps
示例
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
4.0
作为简写,使用 defjvps
,您可以传递一个 None
值来表示特定参数的 JVP 为零
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
None)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
0.0
调用带有关键字参数的 jax.custom_jvp()
函数,或者使用默认参数编写 jax.custom_jvp()
函数定义,只要它们可以基于标准库 inspect.signature
机制检索到的函数签名明确地映射到位置参数,都是允许的。
当您不执行微分时,函数 f
的调用方式就像它没有被 jax.custom_jvp()
装饰一样
@custom_jvp
def f(x):
print('called f!') # a harmless side-effect
return jnp.sin(x)
@f.defjvp
def f_jvp(primals, tangents):
print('called f_jvp!') # a harmless side-effect
x, = primals
t, = tangents
return f(x), jnp.cos(x) * t
print(f(3.))
called f!
0.14112
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.))
called f!
[0. 0.84147096 0.9092974 ]
called f!
0.14112
在微分期间会调用自定义 JVP 规则,无论是前向还是反向
y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)
called f_jvp!
called f!
-0.9899925
print(grad(f)(3.))
called f_jvp!
called f!
-0.9899925
请注意,f_jvp
调用 f
来计算原始输出。在更高阶微分的上下文中,微分变换的每次应用都将使用自定义 JVP 规则,当且仅当该规则调用原始 f
来计算原始输出时。(这代表了一种基本权衡,我们无法利用规则中 f
的评估中的中间值,并且该规则也适用于更高阶微分的所有阶。)
grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
Array(-0.14112, dtype=float32, weak_type=True)
您可以将 Python 控制流与 jax.custom_jvp()
一起使用
@custom_jvp
def f(x):
if x > 0:
return jnp.sin(x)
else:
return jnp.cos(x)
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
if x > 0:
return ans, 2 * x_dot
else:
return ans, 3 * x_dot
print(grad(f)(1.))
print(grad(f)(-1.))
2.0
3.0
使用 jax.custom_vjp
定义自定义的仅反向模式规则#
虽然 jax.custom_jvp()
足以控制前向和反向模式微分行为(通过 JAX 的自动转置),但在某些情况下,我们可能希望直接控制 VJP 规则,例如在上面介绍的后两个示例问题中。我们可以使用 jax.custom_vjp()
来实现
from jax import custom_vjp
# f :: a -> b
@custom_vjp
def f(x):
return jnp.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
return f(x), jnp.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd)
print(f(3.))
print(grad(f)(3.))
0.14112
-0.9899925
换句话说,我们再次从一个原始函数 f
开始,该函数接受类型为 a
的输入并产生类型为 b
的输出。我们将其与两个函数 f_fwd
和 f_bwd
相关联,这两个函数描述了如何分别执行反向模式自动微分的前向和后向传递。
函数 f_fwd
描述了前向传递,不仅包括原始计算,还包括为后向传递存储的值。它的输入签名与原始函数 f
的输入签名相同,因为它接受类型为 a
的原始输入。但是作为输出,它产生一个对,其中第一个元素是原始输出 b
,第二个元素是要存储以供后向传递使用的任何类型为 c
的“剩余”数据。(此第二个输出类似于 PyTorch 的 save_for_backward 机制。)
函数 f_bwd
描述了后向传递。它接受两个输入,其中第一个是 f_fwd
生成的类型为 c
的剩余数据,第二个是与原始函数输出对应的类型为 CT b
的输出余切。它生成一个类型为 CT a
的输出,表示与原始函数输入对应的余切。特别是,f_bwd
的输出必须是长度等于原始函数参数数量的序列(例如,元组)。
因此,多个参数的工作方式如下
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
调用带有关键字参数的 jax.custom_vjp()
函数,或者使用默认参数编写 jax.custom_vjp()
函数定义,只要它们可以基于标准库 inspect.signature
机制检索到的函数签名明确地映射到位置参数,都是允许的。
与 jax.custom_jvp()
一样,如果未应用微分,则不会调用由 f_fwd
和 f_bwd
组成的自定义 VJP 规则。如果使用 jax.jit()
、jax.vmap()
或其他非微分变换对该函数进行评估或变换,则仅调用 f
。
@custom_vjp
def f(x):
print("called f!")
return jnp.sin(x)
def f_fwd(x):
print("called f_fwd!")
return f(x), jnp.cos(x)
def f_bwd(cos_x, y_bar):
print("called f_bwd!")
return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd)
print(f(3.))
called f!
0.14112
print(grad(f)(3.))
called f_fwd!
called f!
called f_bwd!
-0.9899925
y, f_vjp = vjp(f, 3.)
print(y)
called f_fwd!
called f!
0.14112
print(f_vjp(1.))
called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),)
不能在 jax.custom_vjp()
函数上使用前向模式自动微分,并且会引发错误
from jax import jvp
try:
jvp(f, (3.,), (1.,))
except TypeError as e:
print('ERROR! {}'.format(e))
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.
如果要同时使用前向和反向模式,请改用 jax.custom_jvp()
。
我们可以将 jax.custom_vjp()
与 pdb
一起使用,以在后向传递中插入调试器跟踪
import pdb
@custom_vjp
def debug(x):
return x # acts like identity
def debug_fwd(x):
return x, x
def debug_bwd(x, g):
import pdb; pdb.set_trace()
return g
debug.defvjp(debug_fwd, debug_bwd)
def foo(x):
y = x ** 2
y = debug(y) # insert pdb in corresponding backward pass step
return jnp.sin(y)
jax.grad(foo)(3.)
> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q
更多特性和详细信息#
使用 list
/ tuple
/ dict
容器(和其他 pytree)#
您应该期望标准的 Python 容器(如列表、元组、命名元组和字典)以及它们的嵌套版本都能正常工作。一般来说,只要它们的结构根据类型约束保持一致,任何 pytrees 都是允许的。
这里有一个使用 jax.custom_jvp()
的人为示例
from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])
@custom_jvp
def f(pt):
x, y = pt.x, pt.y
return {'a': x ** 2,
'b': (jnp.sin(x), jnp.cos(y))}
@f.defjvp
def f_jvp(primals, tangents):
pt, = primals
pt_dot, = tangents
ans = f(pt)
ans_dot = {'a': 2 * pt.x * pt_dot.x,
'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
return ans, ans_dot
def fun(pt):
dct = f(pt)
return dct['a'] + dct['b'][0]
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True))
以及一个使用 jax.custom_vjp()
的类似的人为示例
@custom_vjp
def f(pt):
x, y = pt.x, pt.y
return {'a': x ** 2,
'b': (jnp.sin(x), jnp.cos(y))}
def f_fwd(pt):
return f(pt), pt
def f_bwd(pt, g):
a_bar, (b0_bar, b1_bar) = g['a'], g['b']
x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
y_bar = -jnp.sin(pt.y) * b1_bar
return (Point(x_bar, y_bar),)
f.defvjp(f_fwd, f_bwd)
def fun(pt):
dct = f(pt)
return dct['a'] + dct['b'][0]
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True))
处理不可微分的参数#
某些用例(如最后一个示例问题)需要将不可微分的参数(如函数值参数)传递给具有自定义微分规则的函数,并且还需要将这些参数传递给规则本身。在 fixed_point
的情况下,函数参数 f
就是这样一个不可微分的参数。 jax.experimental.odeint
也出现了类似的情况。
jax.custom_jvp
与 nondiff_argnums
#
使用 jax.custom_jvp()
的可选参数 nondiff_argnums
来指示这些参数。以下是一个使用 jax.custom_jvp()
的示例
from functools import partial
@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
return f(x)
@app.defjvp
def app_jvp(f, primals, tangents):
x, = primals
x_dot, = tangents
return f(x), 2. * x_dot
print(app(lambda x: x ** 3, 3.))
27.0
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0
注意这里的陷阱:无论这些参数在参数列表中的位置如何,它们都会被放置在相应的 JVP 规则的签名的开头。这里是另一个例子
@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
return f(g((x)))
@app2.defjvp
def app2_jvp(f, g, primals, tangents):
x, = primals
x_dot, = tangents
return f(g(x)), 3. * x_dot
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0
jax.custom_vjp
与 nondiff_argnums
#
jax.custom_vjp()
也存在类似的选项,同样,约定是将不可微分的参数作为第一个参数传递给 _bwd
规则,无论它们在原始函数的签名中出现在哪里。 _fwd
规则的签名保持不变 - 它与原始函数的签名相同。这是一个例子
@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
return f(x)
def app_fwd(f, x):
return f(x), x
def app_bwd(f, x, g):
return (5 * g,)
app.defvjp(app_fwd, app_bwd)
print(app(lambda x: x ** 2, 4.))
16.0
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0
有关其他用法示例,请参阅上面的 fixed_point
。
您无需将 nondiff_argnums
用于数组值参数,例如,具有整数 dtype 的参数。相反,nondiff_argnums
应该仅用于不对应于 JAX 类型(本质上不对应于数组类型)的参数值,例如 Python 可调用对象或字符串。如果 JAX 检测到 nondiff_argnums
指示的参数包含 JAX Tracer,则会引发错误。上面的 clip_gradient
函数就是一个不将 nondiff_argnums
用于整数 dtype 数组参数的好例子。
后续步骤#
还有很多其他的自动微分技巧和功能。本教程未涵盖但值得研究的主题包括:
高斯-牛顿向量积,线性化一次
自定义 VJP 和 JVP
固定点的有效导数
使用随机的 Hessian-向量积估计 Hessian 的迹
仅使用反向模式自动微分实现前向模式自动微分
对自定义数据类型求导
检查点(用于高效反向模式的二项式检查点,而非模型快照)
使用雅可比预累积优化 VJP