Autodidax2,第1部分:从头开始再次构建 JAX#

如果你想了解 JAX 的工作原理,可以尝试阅读代码。但代码很复杂,通常没有充分的理由。本笔记本提供了一个精简版本,去除了不必要的细节。这是一个从基本原理出发的 JAX 最小版本。请享用!

主要思想:上下文敏感的解释#

JAX 有两点

  1. 一组原始操作(大致相当于 NumPy API)

  2. 一组基于这些原始操作的解释器(编译、AD 等)

在这个 JAX 的最小版本中,我们将从两个原始操作开始:加法和乘法,然后逐个添加解释器。假设我们有一个用户定义的函数,如下所示:

def foo(x):
  return mul(x, add(x, 3.0))

我们希望能够在不改变 `foo` 实现的情况下以不同的方式解释它:我们希望在具体值上评估它、对其进行微分、将其阶段性地生成 IR、编译它等等。

我们将这样实现。对于每种解释,我们都会定义一个 `Interpreter` 对象,其中包含处理每个原始操作的规则。我们将使用一个全局上下文变量来跟踪当前解释器。面向用户的函数 `add` 和 `mul` 将分派给当前解释器。在程序开始时,当前解释器将是“求值”解释器,它只对普通的具体数据执行操作。目前看起来是这样的。

from enum import Enum, auto
from contextlib import contextmanager
from typing import Any

# The full (closed) set of primitive operations
class Op(Enum):
  add = auto()  # addition on floats
  mul = auto()  # multiplication on floats

# Interpreters have rules for handling each primitive operation.
class Interpreter:
  def interpret_op(self, op: Op, args: tuple[Any, ...]):
    assert False, "subclass should implement this"

# Our first interpreter is the "evaluating interpreter" which performs ordinary
# concrete evaluation.
class EvalInterpreter:
  def interpret_op(self, op, args):
    assert all(isinstance(arg, float) for arg in args)
    match op:
      case Op.add:
        x, y = args
        return x + y
      case Op.mul:
        x, y = args
        return x * y
      case _:
        raise ValueError(f"Unrecognized primitive op: {op}")

# The current interpreter is initially the evaluating interpreter.
current_interpreter = EvalInterpreter()

# A context manager for temporarily changing the current interpreter
@contextmanager
def set_interpreter(new_interpreter):
  global current_interpreter
  prev_interpreter = current_interpreter
  try:
    current_interpreter = new_interpreter
    yield
  finally:
    current_interpreter = prev_interpreter

# The user-facing functions `mul` and `add` dispatch to the current interpreter.
def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))

此时,我们可以使用普通的具体输入调用 `foo` 并查看结果。

print(foo(2.0))
10.0

旁注:前向模式自动微分#

对于我们的第二个解释器,我们将尝试前向模式自动微分 (AD)。如果你是第一次接触前向模式 AD,这里有一个快速介绍。否则,请跳到“JVPInterpreter”部分。

假设我们对 `foo(x)` 在 `x=2.0` 处的导数感兴趣。我们可以用有限差分来近似它:

print((foo(2.00001) - foo(2.0)) / 0.00001)
7.000009999913458

答案接近 7.0,正如预期。但是用这种方式计算需要两次函数求值(更不用说舍入误差和截断误差了)。然而,有趣的是,我们几乎可以通过一次求值得到答案:

print(foo(2.00001))
10.0000700001

我们正在寻找的答案 7.0 就在那些不重要的数字中!

我们可以这样理解所发生的事情。`foo` 的初始参数 `2.00001` 携带了两个数据:一个“原值”(primal value) 2.0,和一个“切值”(tangent value) `1.0`。这个原值-切值对 `2.00001` 的表示是两者的和,其中切值乘以了一个小的固定 epsilon,即 `1e-5`。对 `foo(2.00001)` 的普通求值会传播这个原值-切值对,产生 `10.0000700001` 作为结果。原值和切值分量在尺度上很好地分离,因此我们可以将结果直观地解释为原值-切值对 (10.0, 7.0),忽略末尾的约 1e-10 截断误差。

前向模式微分的思想是做同样的事情,但要精确和显式(目测浮点数并不真正可行)。我们将把原值-切值对表示为一个实际的对,而不是将它们折叠成一个单一的浮点数。对于每个原始操作,我们都会有一个规则来描述如何传播这些原值-切值对。让我们推导出我们两个原始操作的规则。

加法很简单。考虑 `x + y`,其中 `x = xp + xt * eps` 且 `y = yp + yt * eps`(“p”代表“原值”,“t”代表“切值”)。

 x + y = (xp + xt * eps) + (yp + yt * eps)
       =   (xp + yp)             # primal component
         + (xt + yt) * eps       # tangent component

结果是关于 `eps` 的一阶多项式,我们可以从中读出原值-切值对为 (xp + yp, xt + yt)。

乘法则更有趣:

 x * y = (xp + xt * eps) * (yp + yt * eps)
       =    (xp * yp)                        # primal component
          + (xp * yt + xt * yp) * eps        # tangent component
          + (xt * yt)           * eps * eps  # quadratic component, vanishes in the eps->0 limit

现在我们有一个二阶多项式。但是当 epsilon 趋于零时,二次项消失,我们的原值-切值对就只剩下 `(xp * yp, xp * yt + xt * yp)`(在我们早先的有限 `eps` 例子中,这个项没有消失正是我们有 1e-10 “截断误差”的原因)。

将此转化为代码,我们可以写出加法和乘法的前向 AD 规则,并用这些规则表示 `foo`:

from dataclasses import dataclass

# A primal-tangent pair is conventionally called a "dual number"
@dataclass
class DualNumber:
  primal  : float
  tangent : float

def add_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)

def mul_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)

def foo_dual(x : DualNumber) -> DualNumber:
  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))

print (foo_dual(DualNumber(2.0, 1.0)))
DualNumber(primal=10.0, tangent=7.0)

这样可行!但是将 `foo` 重写为使用加法和乘法的 `_dual` 版本有点繁琐。让我们回到主程序,使用我们的解释器机制自动完成重写。

JVP 解释器#

我们将设置一个新的解释器,名为 `JVPInterpreter`(“JVP”是“雅可比-向量积”的缩写),它传播这些双重数(dual numbers)而不是普通值。`JVPInterpreter` 具有对双重数进行操作的“add”和“mul”方法。它们通过调用 `JVPInterpreter.lift` 将常量参数按需转换为双重数。在我们上面手动重写的版本中,我们通过将字面值 `3.0` 替换为 `DualNumber(3.0, 0.0)` 来实现这一点。

# This is like DualNumber above except that is also has a pointer to the
# interpreter it belongs to, which is needed to avoid "perturbation confusion"
# in higher order differentiation.
@dataclass
class TaggedDualNumber:
  interpreter : Interpreter
  primal  : float
  tangent : float

class JVPInterpreter(Interpreter):
  def __init__(self, prev_interpreter: Interpreter):
    # We keep a pointer to the interpreter that was current when this
    # interpreter was first invoked. That's the context in which our
    # rules should run.
    self.prev_interpreter = prev_interpreter

  def interpret_op(self, op, args):
    args = tuple(self.lift(arg) for arg in args)
    with set_interpreter(self.prev_interpreter):
      match op:
        case Op.add:
          # Notice that we use `add` and `mul` here, which are the
          # interpreter-dispatching functions defined earlier.
          x, y = args
          return self.dual_number(
              add(x.primal, y.primal),
              add(x.tangent, y.tangent))

        case Op.mul:
          x, y = args
          x = self.lift(x)
          y = self.lift(y)
          return self.dual_number(
              mul(x.primal, y.primal),
              add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

  def dual_number(self, primal, tangent):
    return TaggedDualNumber(self, primal, tangent)

  # Lift a constant value (constant with respect to this interpreter) to
  # a TaggedDualNumber.
  def lift(self, x):
    if isinstance(x, TaggedDualNumber) and x.interpreter is self:
      return x
    else:
      return self.dual_number(x, 0.0)

def jvp(f, primal, tangent):
  jvp_interpreter = JVPInterpreter(current_interpreter)
  dual_number_in = jvp_interpreter.dual_number(primal, tangent)
  with set_interpreter(jvp_interpreter):
    result = f(dual_number_in)
  dual_number_out = jvp_interpreter.lift(result)
  return dual_number_out.primal, dual_number_out.tangent

# Let's try it out:
print(jvp(foo, 2.0, 1.0))

# Because we were careful to consider nesting interpreters, higher-order AD
# works out of the box:

def derivative(f, x):
  _, tangent = jvp(f, x, 1.0)
  return tangent

def nth_order_derivative(n, f, x):
  if n == 0:
    return f(x)
  else:
    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)
(10.0, 7.0)
print(nth_order_derivative(0, foo, 2.0))
10.0
print(nth_order_derivative(1, foo, 2.0))
7.0
print(nth_order_derivative(2, foo, 2.0))
2.0
# The rest are zero because `foo` is only a second-order polymonial
print(nth_order_derivative(3, foo, 2.0))
0.0
print(nth_order_derivative(4, foo, 2.0))
0.0

有一些值得讨论的微妙之处。首先,如何判断某个值在微分方面是否是常量?很想说“当且仅当它不是双重数时,它才是常量”。但实际上,由不同 JVPInterpreter 创建的双重数也需要被视为我们当前处理的 JVPInterpreter 的常量。这就是为什么我们需要在 `JVPInterpreter.lift` 中进行 `x.interpreter is self` 检查。这在存在多个 JVPInterpreter 作用域的高阶微分中会出现。这种错误,即你无意中将来自不同解释器的双重数解释为非常量,在文献中有时被称为“扰动混淆”(perturbation confusion)。如果我们在 `JVPInterpreter.lift` 中没有 `and x.interpreter is self` 检查,下面这个示例程序就会给出错误答案。

def f(x):
  # g is constant in its (ignored) argument `y`. Its derivative should be zero
  # but our AD will mess it up if we don't distinguish perturbations from
  # different interpreters.
  def g(y):
    return x
  should_be_zero = derivative(g, 0.0)
  return mul(x, should_be_zero)

print(derivative(f, 0.0))
0.0

另一个微妙之处在于:`JVPInterpreter.add` 和 `JVPInterpreter.mul` 根据原值和切值分量上的加法和乘法来描述双重数上的加法和乘法。但我们不使用普通的 `+` 和 `*` 来实现这一点。相反,我们使用我们自己的 `add` 和 `mul` 函数,它们分派给当前解释器。在调用它们之前,我们将当前解释器设置为上一个解释器,即 `JVPInterpreter` 首次被调用时处于活动状态的解释器。如果我们不这样做,就会发生无限递归,`add` 和 `mul` 会无休止地分派给 `JVPInterpreter`。使用我们自己的 `add` 和 `mul` 而不是普通的 `+` 和 `*` 的优点是,这意味着我们可以嵌套这些解释器并进行高阶 AD。

此时你可能会想:我们是不是只是重新发明了运算符重载?Python 重载了中缀运算符 `+` 和 `*`,以分派给参数的 `__add__` 和 `__mul__` 方法。我们是否可以直接使用那种机制,而不是搞出这一整套解释器业务?是的,实际上。确实,早期的自动微分 (AD) 文献使用“运算符重载”这个术语来描述这种 AD 实现风格。一个细节是,我们不能完全依赖 Python 内置的重载机制,因为那只允许我们重载少数内置的中缀运算符,而我们最终希望重载 NumPy 级别的操作,如 `sin` 和 `cos`。所以我们需要自己的机制。

但有一个更重要的区别:我们的分派是基于上下文的,而传统的 Python 风格重载是基于数据的。这实际上是 JAX 最近的发展。最早的 JAX 版本更像是传统的数据重载。操作的解释器(JAX 行话中的“trace”)将根据附加到该操作参数的数据来选择。我们逐渐使解释器分派决策越来越依赖上下文而不是数据(omnistaging [链接],stackless [链接])。之所以更倾向于基于上下文的解释而不是基于数据的解释,是因为它使实现简单得多。

尽管如此,我们希望利用 Python 的内置重载机制。这样我们就可以方便地使用中缀运算符 `+` 和 `*`,而不是写出 `add(..)` 和 `mul(..)`。但我们暂时把这一点放一边。

3. 阶段性地生成无类型 IR#

到目前为止我们看到的两种程序转换——求值和 JVP——都从上到下遍历输入程序。它们按照与普通求值相同的顺序逐个访问操作。从上到下转换的一个便利之处在于它们可以被急切地或“在线地”实现,这意味着我们可以从上到下评估程序并随之执行必要的转换。我们从不一次性查看整个程序。

但并非所有转换都以这种方式工作。例如,死代码消除需要从下到上遍历,在向上过程中收集使用统计信息并消除其结果未被使用的纯操作。另一个自下而上的转换是 AD 转置(AD transposition),我们用它来实现反向模式 AD。对于这些,我们首先需要将程序“阶段性地”生成为一个 IR(内部表示),这是一种表示程序的数据结构,然后我们可以按任何顺序遍历它。从 Python 程序构建这个 IR 将是我们的第三个也是最后一个解释器的目标。

首先,让我们定义 IR。我们将从一个无类型的 ANF IR 开始。一个函数(在 JAX 中我们称 IR 函数为“jaxprs”)将有一个形参列表、一个操作列表和一个返回值。操作的每个参数必须是一个“原子”,它要么是一个变量,要么是一个字面值。函数的返回值也是一个原子。

Var = str           # Variables are just strings in this untyped IR
Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals

# Equation - a single line in our IR like `z = mul(x, y)`
@dataclass
class Equation:
  var  : Var         # The variable name of the result
  op   : Op          # The primitive operation we're applying
  args : tuple[Atom] # The arguments we're applying the primitive operation to

# We call an IR function a "Jaxpr", for "JAX expression"
@dataclass
class Jaxpr:
  parameters : list[Var]      # The function's formal parameters (arguments)
  equations  : list[Equation] # The body of the function, a list of instructions/equations
  return_val : Atom           # The function's return value

  def __str__(self):
    lines = []
    lines.append(', '.join(b for b in self.parameters) + ' ->')
    for eqn in self.equations:
      args_str = ', '.join(str(arg) for arg in eqn.args)
      lines.append(f'  {eqn.var} = {eqn.op}({args_str})')
    lines.append(self.return_val)
    return '\n'.join(lines)

为了从 Python 函数构建 IR,我们定义了一个 `StagingInterpreter`,它接收每个操作并将其添加到我们目前已看到的所有操作的不断增长的列表中。

class StagingInterpreter(Interpreter):
  def __init__(self):
    self.equations = []         # A mutable list of all the ops we've seen so far
    self.name_counter = 0  # Counter for generating unique names

  def fresh_var(self):
    self.name_counter += 1
    return "v_" + str(self.name_counter)

  def interpret_op(self, op, args):
    binder = self.fresh_var()
    self.equations.append(Equation(binder, op, args))
    return binder

def build_jaxpr(f, num_args):
  interpreter = StagingInterpreter()
  parameters = tuple(interpreter.fresh_var() for _ in range(num_args))
  with set_interpreter(interpreter):
    result = f(*parameters)
  return Jaxpr(parameters, interpreter.equations, result)

现在我们可以为 Python 程序构建一个 IR 并将其打印出来。

print(build_jaxpr(foo, 1))
v_1 ->
  v_2 = Op.add(v_1, 3.0)
  v_3 = Op.mul(v_1, v_2)
v_3

我们还可以通过编写一个显式解释器来逐个遍历操作,从而评估我们的 IR。

def eval_jaxpr(jaxpr, args):
  # An environment mapping variables to values
  env = dict(zip(jaxpr.parameters, args))
  def eval_atom(x): return env[x] if isinstance(x, Var) else x
  for eqn in jaxpr.equations:
    args = tuple(eval_atom(x) for x in eqn.args)
    env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)
  return eval_atom(jaxpr.return_val)

print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
10.0

我们已经根据 `current_interpreter.interpret_op` 编写了这个解释器,这意味着我们完成了一个完整的往返:可解释的 Python 程序到 IR 再到可解释的 Python 程序。由于结果是“可解释的”,我们可以再次对其进行微分,或者阶段性地生成它,或者我们喜欢的任何操作。

print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))
(10.0, 7.0)

接下来…#

本教程第一部分到此结束。我们已经完成了两个原始操作,三个解释器以及将它们编织在一起的跟踪机制。在下一部分中,我们将添加除浮点数以外的类型、错误处理、编译、反向模式 AD 和高阶原始操作。请注意,第二部分的结构有所不同。我们不再尝试遵循既满足代码依赖(例如,数据结构在使用前需要定义)又满足教学依赖(概念在实现前需要引入)的自上而下顺序,而是采用一个可以按任何顺序阅读的单一文件。