有状态计算#
JAX 的转换,如 jit()、vmap()、grad(),要求它们包装的函数必须是纯函数:也就是说,函数的输出**仅**取决于输入,并且没有任何副作用,例如更新全局状态。您可以在 JAX 尖端知识:纯函数 中找到对此的讨论。
在机器学习的背景下,这种限制可能会带来一些挑战,因为状态可能以多种形式存在。例如:
模型参数,
优化器状态,以及
有状态层,例如 批量归一化。
本节将提供一些关于如何在 JAX 程序中正确处理状态的建议。
一个简单的例子:计数器#
让我们从一个简单的有状态程序开始:计数器。
import jax
import jax.numpy as jnp
class Counter:
"""A simple counter."""
def __init__(self):
self.n = 0
def count(self) -> int:
"""Increments the counter and returns the new value."""
self.n += 1
return self.n
def reset(self):
"""Resets the counter to zero."""
self.n = 0
counter = Counter()
for _ in range(3):
print(counter.count())
1
2
3
计数器的 n 属性在 successive 调用 count 之间维护计数器的状态。它是作为调用 count 的副作用而修改的。
假设我们想快速计数,所以我们对 count 方法进行 JIT 编译。(在这个例子中,由于许多原因,这实际上并不能帮助提速,但请将其视为对模型参数更新进行 JIT 编译的玩具模型,其中 jit() 能带来巨大的提升。)
counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
print(fast_count())
1
1
1
糟糕!我们的计数器不起作用了。这是因为 count 中的这行代码
self.n += 1
涉及副作用:它会就地修改输入计数器,因此 jit 不支持此函数。这种副作用仅在函数首次跟踪时执行,后续调用不会重复该副作用。那么,我们该如何解决呢?
解决方案:显式状态#
我们的计数器问题的一部分是返回的值不依赖于参数,这意味着一个常量被“烘焙”到了编译后的输出中。但它不应该是常量——它应该依赖于状态。那么,为什么不让状态成为一个参数呢?
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1
def reset(self) -> CounterState:
return 0
counter = CounterV2()
state = counter.reset()
for _ in range(3):
value, state = counter.count(state)
print(value)
1
2
3
在这个 Counter 的新版本中,我们将 n 移到了 count 的参数中,并添加了另一个表示新更新状态的返回值。为了使用这个计数器,我们现在需要显式地跟踪状态。但作为回报,我们可以安全地对这个计数器进行 jax.jit 编译
state = counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
value, state = fast_count(state)
print(value)
1
2
3
通用策略#
我们可以将相同的过程应用于任何有状态方法,将其转换为无状态方法。我们采用了一个形式为
class StatefulClass
state: State
def stateful_method(*args, **kwargs) -> Output:
的类,并将其转换为一个形式为
class StatelessClass
def stateless_method(state: State, *args, **kwargs) -> (Output, State):
的类。这是一种常见的 函数式编程 模式,本质上是所有 JAX 程序处理状态的方式。
请注意,一旦我们这样重写了类,对类的需求就变得不那么明显了。我们可以只保留 stateless_method,因为类不再执行任何工作。这是因为,就像我们刚刚应用过的策略一样,面向对象编程(OOP)是一种帮助程序员理解程序状态的方式。
在我们的例子中,CounterV2 类仅仅是一个命名空间,将所有使用 CounterState 的函数集中在一个地方。读者练习:您认为保留它作为一个类是否有意义?
顺便说一句,您已经在 JAX 伪随机数 API jax.random 中看到了这个策略的示例,如 伪随机数 部分所示。与 Numpy 不同,Numpy 使用隐式更新的有状态类来管理随机状态,JAX 要求程序员直接处理随机生成器状态——PRNG 密钥。
简单的实战示例:线性回归#
让我们将这个策略应用于一个简单的机器学习模型:通过梯度下降进行线性回归。
在这里,我们只处理一种状态:模型参数。但通常,您会看到多种状态在 JAX 函数之间传递,例如优化器状态、批量归一化的层统计量等。
需要仔细查看的函数是 update。
from typing import NamedTuple
class Params(NamedTuple):
weight: jnp.ndarray
bias: jnp.ndarray
def init(rng) -> Params:
"""Returns the initial model params."""
weights_key, bias_key = jax.random.split(rng)
weight = jax.random.normal(weights_key, ())
bias = jax.random.normal(bias_key, ())
return Params(weight, bias)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Computes the least squares error of the model's predictions on x against y."""
pred = params.weight * x + params.bias
return jnp.mean((pred - y) ** 2)
LEARNING_RATE = 0.005
@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
"""Performs one SGD update step on params using the given data."""
grad = jax.grad(loss)(params, x, y)
# If we were using Adam or another stateful optimizer,
# we would also do something like
#
# updates, new_optimizer_state = optimizer(grad, optimizer_state)
#
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)
new_params = jax.tree.map(
lambda param, g: param - g * LEARNING_RATE, params, grad)
return new_params
请注意,我们手动将 params 传入和传出 update 函数。
import matplotlib.pyplot as plt
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise
# Fit regression
params = init(rng)
for _ in range(1000):
params = update(params, xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();
更进一步#
上面描述的策略是任何 JAX 程序在使用 jit、vmap、grad 等转换时必须处理状态的方式。
如果只处理两个参数,手动处理参数看起来还可以,但如果是一个包含数十个层的神经网络呢?您可能已经开始担心两件事:
我们是否应该手动初始化所有这些参数,本质上重复我们在前向传播定义中已经写过的东西?
我们是否应该手动传递所有这些东西?
细节可能很难处理,但有一些库可以为您处理这些。有关示例,请参阅 JAX 生态系统库。