Autodidax2,第 1 部分:从零开始重构 JAX#

如果你想了解 JAX 的工作原理,你可以尝试阅读其源码。但这些代码很复杂,且通常没有充分的理由。本笔记提供了一个去繁就简的版本。这是一个基于基本原理的极简版 JAX。尽情享受吧!

核心理念:上下文相关解释#

JAX 由两部分组成:

  1. 一组原始操作(大致对应 NumPy API)

  2. 一组作用于这些原语的解释器(编译、自动微分等)

在这个极简版 JAX 中,我们将仅从加法和乘法这两个原始操作开始,并逐个添加解释器。假设我们有一个如下所示的用户定义函数:

def foo(x):
  return mul(x, add(x, 3.0))

我们希望在不改变 foo 实现的前提下,以不同方式解释它:在具体值上进行求值、对其求导、将其暂存为中间表示(IR)、对其进行编译等等。

实现方式如下:对于每种解释,我们将定义一个 Interpreter 对象,其中包含处理每个原始操作的规则。我们使用全局上下文变量来跟踪当前解释器。面向用户的函数 addmul 将调度到当前解释器。程序开始时,当前解释器为“求值”解释器,它仅对普通具体数据进行求值。目前为止,大致框架如下。

from enum import Enum, auto
from contextlib import contextmanager
from typing import Any

# The full (closed) set of primitive operations
class Op(Enum):
  add = auto()  # addition on floats
  mul = auto()  # multiplication on floats

# Interpreters have rules for handling each primitive operation.
class Interpreter:
  def interpret_op(self, op: Op, args: tuple[Any, ...]):
    assert False, "subclass should implement this"

# Our first interpreter is the "evaluating interpreter" which performs ordinary
# concrete evaluation.
class EvalInterpreter:
  def interpret_op(self, op, args):
    assert all(isinstance(arg, float) for arg in args)
    match op:
      case Op.add:
        x, y = args
        return x + y
      case Op.mul:
        x, y = args
        return x * y
      case _:
        raise ValueError(f"Unrecognized primitive op: {op}")

# The current interpreter is initially the evaluating interpreter.
current_interpreter = EvalInterpreter()

# A context manager for temporarily changing the current interpreter
@contextmanager
def set_interpreter(new_interpreter):
  global current_interpreter
  prev_interpreter = current_interpreter
  try:
    current_interpreter = new_interpreter
    yield
  finally:
    current_interpreter = prev_interpreter

# The user-facing functions `mul` and `add` dispatch to the current interpreter.
def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))

此时,我们可以用普通的具体输入调用 foo 并查看结果。

print(foo(2.0))
10.0

旁白:前向模式自动微分#

对于第二个解释器,我们将尝试前向模式自动微分 (AD)。如果你是第一次接触,这里有一个简短的介绍。否则,请跳至“JVPInterpreter”部分。

假设我们对 foo(x)x=2.0 处的导数感兴趣。我们可以用有限差分法来近似它:

print((foo(2.00001) - foo(2.0)) / 0.00001)
7.000009999913458

结果如预期接近 7.0。但这种计算方式需要对函数进行两次求值(更不用说舍入误差和截断误差了)。这里有个有趣的事情:我们几乎可以通过单次求值得到答案:

print(foo(2.00001))
10.0000700001

我们想要的答案 7.0 就隐藏在微小的尾数中!

理解这一点的一种方法是:foo 的初始参数 2.00001 携带了两个数据:一个“原值”(primal) 2.0 和一个“切线值”(tangent) 1.0。这种原值-切线对的表示法(2.00001)是两者的和,其中切线由一个小常数 epsilon (1e-5) 缩放。对 foo(2.00001) 进行普通求值会传播这一对数值,产生 10.0000700001 作为结果。原值和切线分量在数量级上区分明显,因此我们可以直观地将结果解读为 (10.0, 7.0),忽略最后约 1e-10 的截断误差。

前向模式微分的想法是做同样的事情,但要精确且明确(靠观察浮点数毕竟不够严谨)。我们将用实际的对 (pair) 来表示原值-切线对,而不是将它们合并到一个浮点数中。对于每个原始操作,我们定义一条规则,描述如何传播这些原值-切线对。让我们推导出这两个原语的规则。

加法很简单。考虑 x + y,其中 x = xp + xt * epsy = yp + yt * eps(“p”代表原值,“t”代表切线)。

 x + y = (xp + xt * eps) + (yp + yt * eps)
       =   (xp + yp)             # primal component
         + (xt + yt) * eps       # tangent component

结果是一个关于 eps 的一阶多项式,我们可以直接读出原值-切线对为 (xp + yp, xt + yt)。

乘法则更有趣:

 x * y = (xp + xt * eps) * (yp + yt * eps)
       =    (xp * yp)                        # primal component
          + (xp * yt + xt * yp) * eps        # tangent component
          + (xt * yt)           * eps * eps  # quadratic component, vanishes in the eps->0 limit

现在我们得到一个二阶多项式。但当 epsilon 趋于零时,二次项消失,我们的原值-切线对就变成了 (xp * yp, xp * yt + xt * yp)(在之前使用有限 eps 的例子中,正是因为这一项没有消失,才产生了 1e-10 的“截断误差”)。

将其转化为代码,我们可以写出加法和乘法的前向 AD 规则,并以此表达 foo

from dataclasses import dataclass

# A primal-tangent pair is conventionally called a "dual number"
@dataclass
class DualNumber:
  primal  : float
  tangent : float

def add_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)

def mul_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)

def foo_dual(x : DualNumber) -> DualNumber:
  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))

print (foo_dual(DualNumber(2.0, 1.0)))
DualNumber(primal=10.0, tangent=7.0)

这行得通!但重写 foo 以使用加法和乘法的 _dual 版本有些繁琐。让我们回到主程序,利用我们的解释机制来自动完成这个重写过程。

JVP 解释器#

我们将设置一个名为 JVPInterpreter(“JVP”代表“雅可比-向量积”)的新解释器,它传播这些对偶数而不是普通值。JVPInterpreter 拥有操作对偶数的方法 'add' 和 'mul'。它们会根据需要通过调用 JVPInterpreter.lift 将常量参数转换为对偶数。在我们上面手动重写的版本中,我们是通过将字面量 3.0 替换为 DualNumber(3.0, 0.0) 来实现这一点的。

# This is like DualNumber above except that is also has a pointer to the
# interpreter it belongs to, which is needed to avoid "perturbation confusion"
# in higher order differentiation.
@dataclass
class TaggedDualNumber:
  interpreter : Interpreter
  primal  : float
  tangent : float

class JVPInterpreter(Interpreter):
  def __init__(self, prev_interpreter: Interpreter):
    # We keep a pointer to the interpreter that was current when this
    # interpreter was first invoked. That's the context in which our
    # rules should run.
    self.prev_interpreter = prev_interpreter

  def interpret_op(self, op, args):
    args = tuple(self.lift(arg) for arg in args)
    with set_interpreter(self.prev_interpreter):
      match op:
        case Op.add:
          # Notice that we use `add` and `mul` here, which are the
          # interpreter-dispatching functions defined earlier.
          x, y = args
          return self.dual_number(
              add(x.primal, y.primal),
              add(x.tangent, y.tangent))

        case Op.mul:
          x, y = args
          x = self.lift(x)
          y = self.lift(y)
          return self.dual_number(
              mul(x.primal, y.primal),
              add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

  def dual_number(self, primal, tangent):
    return TaggedDualNumber(self, primal, tangent)

  # Lift a constant value (constant with respect to this interpreter) to
  # a TaggedDualNumber.
  def lift(self, x):
    if isinstance(x, TaggedDualNumber) and x.interpreter is self:
      return x
    else:
      return self.dual_number(x, 0.0)

def jvp(f, primal, tangent):
  jvp_interpreter = JVPInterpreter(current_interpreter)
  dual_number_in = jvp_interpreter.dual_number(primal, tangent)
  with set_interpreter(jvp_interpreter):
    result = f(dual_number_in)
  dual_number_out = jvp_interpreter.lift(result)
  return dual_number_out.primal, dual_number_out.tangent

# Let's try it out:
print(jvp(foo, 2.0, 1.0))

# Because we were careful to consider nesting interpreters, higher-order AD
# works out of the box:

def derivative(f, x):
  _, tangent = jvp(f, x, 1.0)
  return tangent

def nth_order_derivative(n, f, x):
  if n == 0:
    return f(x)
  else:
    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)
(10.0, 7.0)
print(nth_order_derivative(0, foo, 2.0))
10.0
print(nth_order_derivative(1, foo, 2.0))
7.0
print(nth_order_derivative(2, foo, 2.0))
2.0
# The rest are zero because `foo` is only a second-order polymonial
print(nth_order_derivative(3, foo, 2.0))
0.0
print(nth_order_derivative(4, foo, 2.0))
0.0

这里有一些微妙之处值得讨论。首先,如何判断某物相对于微分是否为常量?我们很想说“如果它不是对偶数,它就是常量”。但实际上,由另一个 JVPInterpreter 创建的对偶数也必须被视为相对于我们当前处理的 JVPInterpreter 的常量。这就是为什么我们需要在 JVPInterpreter.lift 中进行 x.interpreter is self 的检查。这在作用域内存在多个 JVPInterpreter 的高阶微分中很有用。那种因错误地将其他解释器的对偶数视为非常量而导致的 bug,在文献中被称为“扰动混淆”(perturbation confusion)。如果没有 and x.interpreter is self 检查,下面就是一个会导致结果错误的示例程序。

def f(x):
  # g is constant in its (ignored) argument `y`. Its derivative should be zero
  # but our AD will mess it up if we don't distinguish perturbations from
  # different interpreters.
  def g(y):
    return x
  should_be_zero = derivative(g, 0.0)
  return mul(x, should_be_zero)

print(derivative(f, 0.0))
0.0

另一个微妙之处:JVPInterpreter.addJVPInterpreter.mul 根据原值和切线分量上的加法和乘法描述了对偶数上的加法和乘法。但我们并不直接使用普通的 +*。相反,我们使用自己的 addmul 函数,它们会调度到当前解释器。在调用它们之前,我们将当前解释器设置为前一个解释器(即 JVPInterpreter 首次被调用时处于活动的解释器)。如果不这样做,我们将陷入无限递归,addmul 将无休止地调度到 JVPInterpreter。使用我们自己的 addmul 而非普通运算符的优势在于,我们可以嵌套这些解释器并进行高阶 AD。

此时你可能会问:我们是不是刚刚重新发明了运算符重载?Python 通过重载中缀运算符 +* 来调度到参数的 __add____mul__。我们本可以直接使用该机制,而不是整个解释器那一套吗?是的,确实可以。事实上,早期的自动微分 (AD) 文献使用“运算符重载”一词来描述这种 AD 实现风格。细节在于,我们不能仅仅依赖 Python 内置的重载,因为那只允许我们重载少数几个内置中缀运算符,而我们最终需要重载像 sincos 这样的 NumPy 级操作。所以我们需要自己的机制。

但有一个更重要的区别:我们的调度基于上下文,而传统的 Python 风格重载基于数据。这其实是 JAX 最近的发展方向。JAX 的早期版本更像传统的基于数据的重载。一个操作的解释器(JAX 术语中称为“跟踪”或 trace)是根据附加在操作参数上的数据来选择的。我们逐渐将解释器调度的决策更多地依赖于上下文而非数据(omnistaging [链接], stackless [链接])。选择基于上下文的解释而非基于数据的解释,原因在于它让实现简单得多。

尽管如此,我们确实也想利用 Python 的内置重载机制。这样我们就能在编写代码时享受使用中缀运算符 +* 的便利,而不必写出 add(..)mul(..)。但我们暂时先把这一点搁置一旁。

3. 暂存为无类型中间表示 (IR)#

我们目前见到的两种程序转换——求值和 JVP——都是从头到尾遍历输入程序。它们按照与普通求值相同的顺序逐个访问操作。自顶向下转换的一个优点是它们可以被立即执行(或“在线”执行),这意味着我们可以从上到下遍历程序并在运行过程中执行必要的转换,而无需一次性查看整个程序。

但并非所有转换都以这种方式工作。例如,死代码消除需要从底向上遍历,在向上收集使用统计信息的同时,消除那些结果未被使用的纯操作。另一种自底向上的转换是 AD 转置,我们用它来实现反向模式 AD。对于这些,我们需要先将程序“暂存”到一个 IR(内部表示)中,这是一种代表程序的中间数据结构,之后我们可以按任何我们喜欢的顺序对其进行遍历。从 Python 程序构建这种 IR 将是我们第三个也是最后一个解释器的目标。

首先,让我们定义这个 IR。我们先从无类型的 ANF IR 开始。函数(我们在 JAX 中称 IR 函数为“jaxprs”)将包含形式参数列表、操作列表和返回值。操作的每个参数必须是一个“原子”(atom),即变量或字面量。函数的返回值也是一个原子。

Var = str           # Variables are just strings in this untyped IR
Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals

# Equation - a single line in our IR like `z = mul(x, y)`
@dataclass
class Equation:
  var  : Var         # The variable name of the result
  op   : Op          # The primitive operation we're applying
  args : tuple[Atom] # The arguments we're applying the primitive operation to

# We call an IR function a "Jaxpr", for "JAX expression"
@dataclass
class Jaxpr:
  parameters : list[Var]      # The function's formal parameters (arguments)
  equations  : list[Equation] # The body of the function, a list of instructions/equations
  return_val : Atom           # The function's return value

  def __str__(self):
    lines = []
    lines.append(', '.join(b for b in self.parameters) + ' ->')
    for eqn in self.equations:
      args_str = ', '.join(str(arg) for arg in eqn.args)
      lines.append(f'  {eqn.var} = {eqn.op}({args_str})')
    lines.append(self.return_val)
    return '\n'.join(lines)

为了从 Python 函数构建 IR,我们定义一个 StagingInterpreter,它获取每个操作并将其添加到我们目前已看到的所有操作列表中。

class StagingInterpreter(Interpreter):
  def __init__(self):
    self.equations = []         # A mutable list of all the ops we've seen so far
    self.name_counter = 0  # Counter for generating unique names

  def fresh_var(self):
    self.name_counter += 1
    return "v_" + str(self.name_counter)

  def interpret_op(self, op, args):
    binder = self.fresh_var()
    self.equations.append(Equation(binder, op, args))
    return binder

def build_jaxpr(f, num_args):
  interpreter = StagingInterpreter()
  parameters = tuple(interpreter.fresh_var() for _ in range(num_args))
  with set_interpreter(interpreter):
    result = f(*parameters)
  return Jaxpr(parameters, interpreter.equations, result)

现在我们可以为 Python 程序构建 IR 并将其打印出来。

print(build_jaxpr(foo, 1))
v_1 ->
  v_2 = Op.add(v_1, 3.0)
  v_3 = Op.mul(v_1, v_2)
v_3

我们还可以通过编写一个显式遍历操作的解释器来求值我们的 IR。

def eval_jaxpr(jaxpr, args):
  # An environment mapping variables to values
  env = dict(zip(jaxpr.parameters, args))
  def eval_atom(x): return env[x] if isinstance(x, Var) else x
  for eqn in jaxpr.equations:
    args = tuple(eval_atom(x) for x in eqn.args)
    env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)
  return eval_atom(jaxpr.return_val)

print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
10.0

我们根据 current_interpreter.interpret_op 编写了这个解释器,这意味着我们完成了一个完整的往返:可解释的 Python 程序 -> IR -> 可解释的 Python 程序。由于结果是“可解释的”,我们可以再次对其进行微分、暂存或任何我们想要的操作。

print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))
(10.0, 7.0)

接下来……#

本教程的第一部分到此结束。我们完成了两个原语、三个解释器以及将它们编织在一起的跟踪机制。在下一部分中,我们将添加除浮点数之外的类型、错误处理、编译、反向模式 AD 和高阶原语。注意,第二部分的结构有所不同。与其试图采用既遵循代码依赖关系(例如数据结构定义需先于使用)又遵循教学依赖关系(概念实现需先于介绍)的自顶向下顺序,我们将采用一个可以按任何顺序阅读的单文件结构。