Autodidax2,第一部分:从零开始的 JAX,再来一次#

如果你想了解 JAX 的工作原理,你可以尝试阅读代码。但是代码很复杂,通常没有充分的理由。这个 notebook 提供了一个精简版本,没有了冗余。这是一个从第一性原理出发的 JAX 的最小版本。尽情享用!

主要思想:上下文相关的解释#

JAX 有两个方面

  1. 一组原始操作(大致是 NumPy API)

  2. 一组针对这些原语的解释器(编译、AD 等)

在这个最小版本的 JAX 中,我们将从两个原始操作开始,加法和乘法,我们将逐个添加解释器。假设我们有一个像这样的用户自定义函数

def foo(x):
  return mul(x, add(x, 3.0))

我们希望能够以不同的方式解释 foo,而无需更改其实现:我们希望在具体值上评估它,对其进行微分,将其暂存到 IR,编译它等等。

这是我们将要实现的方式。对于每种解释,我们将定义一个 Interpreter 对象,其中包含处理每个原始操作的规则。我们将使用全局上下文变量跟踪当前解释器。面向用户的函数 addmul 将调度到当前解释器。在程序开始时,当前解释器将是“评估”解释器,它只是评估普通具体数据上的操作。以下是到目前为止的所有内容。

from enum import Enum, auto
from contextlib import contextmanager
from typing import Any

# The full (closed) set of primitive operations
class Op(Enum):
  add = auto()  # addition on floats
  mul = auto()  # multiplication on floats

# Interpreters have rules for handling each primitive operation.
class Interpreter:
  def interpret_op(self, op: Op, args: tuple[Any, ...]):
    assert False, "subclass should implement this"

# Our first interpreter is the "evaluating interpreter" which performs ordinary
# concrete evaluation.
class EvalInterpreter:
  def interpret_op(self, op, args):
    assert all(isinstance(arg, float) for arg in args)
    match op:
      case Op.add:
        x, y = args
        return x + y
      case Op.mul:
        x, y = args
        return x * y
      case _:
        raise ValueError(f"Unrecognized primitive op: {op}")

# The current interpreter is initially the evaluating interpreter.
current_interpreter = EvalInterpreter()

# A context manager for temporarily changing the current interpreter
@contextmanager
def set_interpreter(new_interpreter):
  global current_interpreter
  prev_interpreter = current_interpreter
  try:
    current_interpreter = new_interpreter
    yield
  finally:
    current_interpreter = prev_interpreter

# The user-facing functions `mul` and `add` dispatch to the current interpreter.
def add(x, y): return current_interpreter.interpret_op(Op.add, (x, y))
def mul(x, y): return current_interpreter.interpret_op(Op.mul, (x, y))

此时,我们可以使用普通的具体输入调用 foo 并查看结果

print(foo(2.0))
10.0

题外话:前向模式自动微分#

对于我们的第二个解释器,我们将尝试前向模式自动微分 (AD)。这是一个前向模式 AD 的快速介绍,以防这是你第一次遇到它。否则,请跳到“JVPInterprer”部分。

假设我们对在 x=2.0 处评估的 foo(x) 的导数感兴趣。我们可以使用有限差分来近似它

print((foo(2.00001) - foo(2.0)) / 0.00001)
7.000009999913458

答案接近预期的 7.0。但是以这种方式计算它需要两次函数评估(更不用说舍入误差和截断误差)。不过,这里有一件有趣的事情。我们几乎可以通过一次评估得到答案

print(foo(2.00001))
10.0000700001

我们正在寻找的答案 7.0,就在不重要的数字中!

这是思考正在发生的事情的一种方式。 foo 的初始参数 2.00001 携带两个数据片段:“原始”值 2.0 和“切线”值 1.0。此原始-切线对 2.00001 的表示形式是两者的总和,其中切线按一个小的固定 epsilon 1e-5 缩放。对 foo(2.00001) 的普通评估会传播这个原始-切线对,产生 10.0000700001 作为结果。原始和切线分量在尺度上很好地分离,因此我们可以直观地将结果解释为原始-切线对 (10.0, 7.0),忽略末尾的 ~1e-10 截断误差。

前向模式微分的思想是做同样的事情,但要精确和显式地进行(用眼睛观察浮点数并不能真正扩展)。我们将原始-切线对表示为实际的对,而不是将两者都折叠成一个浮点数。对于每个原始操作,我们将有一个规则来描述如何传播这些原始切线对。让我们计算出我们两个原语的规则。

加法很容易。考虑 x + y,其中 x = xp + xt * epsy = yp + yt * eps(“p”代表“原始”,“t”代表“切线”)

 x + y = (xp + xt * eps) + (yp + yt * eps)
       =   (xp + yp)             # primal component
         + (xt + yt) * eps       # tangent component

结果是 eps 中的一阶多项式,我们可以将原始-切线对读作 (xp + yp, xt + yt)。

乘法更有趣

 x * y = (xp + xt * eps) * (yp + yt * eps)
       =    (xp * yp)                        # primal component
          + (xp * yt + xt * yp) * eps        # tangent component
          + (xt * yt)           * eps * eps  # quadratic component, vanishes in the eps->0 limit

现在我们有一个二阶多项式。但是随着 epsilon 趋于零,二次项消失了,我们的原始-切线对只是 (xp * yp, xp * yt + xt * yp)(在我们早期的有限 eps 示例中,这个项没有消失是我们有 1e-10 “截断误差”的原因)。

将其放入代码中,我们可以写下加法和乘法的前向 AD 规则,并用这些规则表示 foo

from dataclasses import dataclass

# A primal-tangent pair is conventionally called a "dual number"
@dataclass
class DualNumber:
  primal  : float
  tangent : float

def add_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal + y.primal, x.tangent + y.tangent)

def mul_dual(x : DualNumber, y: DualNumber) -> DualNumber:
  return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)

def foo_dual(x : DualNumber) -> DualNumber:
  return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))

print (foo_dual(DualNumber(2.0, 1.0)))
DualNumber(primal=10.0, tangent=7.0)

这行得通!但是重写 foo 以使用 _dual 版本的加法和乘法有点乏味。让我们回到主程序,并使用我们的解释机制来自动执行重写。

JVP 解释器#

我们将设置一个新的解释器,称为 JVPInterpreter(“JVP”代表“雅可比向量积”),它传播这些对偶数而不是普通值。 JVPInterpreter 具有对偶数运算的方法 'add' 和 'mul'。它们通过调用 JVPInterpreter.lift 根据需要将常量参数转换为对偶数。在我们手动重写的版本中,我们通过将字面量 3.0 替换为 DualNumber(3.0, 0.0) 来做到这一点。

# This is like DualNumber above except that is also has a pointer to the
# interpreter it belongs to, which is needed to avoid "perturbation confusion"
# in higher order differentiation.
@dataclass
class TaggedDualNumber:
  interpreter : Interpreter
  primal  : float
  tangent : float

class JVPInterpreter(Interpreter):
  def __init__(self, prev_interpreter: Interpreter):
    # We keep a pointer to the interpreter that was current when this
    # interpreter was first invoked. That's the context in which our
    # rules should run.
    self.prev_interpreter = prev_interpreter

  def interpret_op(self, op, args):
    args = tuple(self.lift(arg) for arg in args)
    with set_interpreter(self.prev_interpreter):
      match op:
        case Op.add:
          # Notice that we use `add` and `mul` here, which are the
          # interpreter-dispatching functions defined earlier.
          x, y = args
          return self.dual_number(
              add(x.primal, y.primal),
              add(x.tangent, y.tangent))

        case Op.mul:
          x, y = args
          x = self.lift(x)
          y = self.lift(y)
          return self.dual_number(
              mul(x.primal, y.primal),
              add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

  def dual_number(self, primal, tangent):
    return TaggedDualNumber(self, primal, tangent)

  # Lift a constant value (constant with respect to this interpreter) to
  # a TaggedDualNumber.
  def lift(self, x):
    if isinstance(x, TaggedDualNumber) and x.interpreter is self:
      return x
    else:
      return self.dual_number(x, 0.0)

def jvp(f, primal, tangent):
  jvp_interpreter = JVPInterpreter(current_interpreter)
  dual_number_in = jvp_interpreter.dual_number(primal, tangent)
  with set_interpreter(jvp_interpreter):
    result = f(dual_number_in)
  dual_number_out = jvp_interpreter.lift(result)
  return dual_number_out.primal, dual_number_out.tangent

# Let's try it out:
print(jvp(foo, 2.0, 1.0))

# Because we were careful to consider nesting interpreters, higher-order AD
# works out of the box:

def derivative(f, x):
  _, tangent = jvp(f, x, 1.0)
  return tangent

def nth_order_derivative(n, f, x):
  if n == 0:
    return f(x)
  else:
    return derivative(lambda x: nth_order_derivative(n-1, f, x), x)
(10.0, 7.0)
print(nth_order_derivative(0, foo, 2.0))
10.0
print(nth_order_derivative(1, foo, 2.0))
7.0
print(nth_order_derivative(2, foo, 2.0))
2.0
# The rest are zero because `foo` is only a second-order polymonial
print(nth_order_derivative(3, foo, 2.0))
0.0
print(nth_order_derivative(4, foo, 2.0))
0.0

有一些值得讨论的微妙之处。首先,你如何判断某物相对于微分是否是常数?很容易说“如果且仅当它不是对偶数时,它才是常数”。但实际上,由不同 JVPInterpreter 创建的对偶数也需要被视为相对于我们当前正在处理的 JVPInterpreter 的常数。这就是为什么我们需要 JVPInterpreter.lift 中的 x.interpreter is self 检查。这在存在多个 JVPInterpreter 范围的高阶微分中出现。你意外地将来自不同解释器的对偶数解释为非恒定的那种错误有时在文献中被称为“扰动混淆”。这是一个示例程序,如果我们没有在 JVPInterpreter.lift 中进行 and x.interpreter is self 检查,它会给出错误的答案。

def f(x):
  # g is constant in its (ignored) argument `y`. Its derivative should be zero
  # but our AD will mess it up if we don't distinguish perturbations from
  # different interpreters.
  def g(y):
    return x
  should_be_zero = derivative(g, 0.0)
  return mul(x, should_be_zero)

print(derivative(f, 0.0))
0.0

另一个微妙之处: JVPInterpreter.addJVPInterpreter.mul 以原始和切线分量的加法和乘法来描述对偶数上的加法和乘法。但是我们不使用普通的 +* 来实现这一点。相反,我们使用我们自己的 addmul 函数,它们调度到当前解释器。在调用它们之前,我们将当前解释器设置为先前的解释器,即首次调用 JVPInterpreter 时当前的解释器。如果我们不这样做,我们将有一个无限递归,其中 addmul 无休止地调度到 JVPInterpreter。使用我们自己的 addmul 而不是普通的 +* 的优势在于,这意味着我们可以嵌套这些解释器并进行高阶 AD。

此时你可能会想:我们是否只是重新发明了运算符重载? Python 重载了中缀运算符 +* 以调度到参数的 __add____mul__。我们可以只使用该机制而不是整个解释器业务吗?是的,实际上可以。实际上,早期的自动微分 (AD) 文献使用术语“运算符重载”来描述这种风格的 AD 实现。一个细节是我们不能完全依赖 Python 内置的重载,因为这只允许我们重载少数内置的中缀运算符,而我们最终希望重载 numpy 级别的操作,如 sincos。因此,我们需要我们自己的机制。

但有一个更重要的区别:我们的调度是基于上下文,而传统的 Python 风格的重载是基于数据。这实际上是 JAX 最近的发展。最早版本的 JAX 看起来更像传统的数据驱动的重载。操作的解释器(JAX 术语中的“trace”)将根据附加到该操作参数的数据来选择。我们逐渐使解释器调度决策更多地依赖于上下文而不是数据(omnistaging [链接], stackless [链接])。相对于数据驱动的解释,更喜欢上下文驱动的解释的原因是它使实现更加简单。

所有这些都说明了,我们希望利用 Python 的内置重载机制。这样,我们就可以获得使用中缀运算符 +* 而不是写出 add(..)mul(..) 的语法便利性。但我们现在先把它放在一边。

3. 暂存到无类型 IR#

到目前为止,我们已经看到的两个程序转换——评估和 JVP——都从上到下遍历输入程序。它们以与普通评估相同的顺序逐个访问操作。关于从上到下的转换的一个方便之处在于,它们可以急切地或“在线”地实现,这意味着我们可以从上到下评估程序,并在进行过程中执行必要的转换。我们永远不会一次查看整个程序。

但并非所有转换都以这种方式工作。例如,死代码消除需要从下到上遍历,在上升过程中收集使用统计信息,并消除结果没有用途的纯操作。另一个从下到上的转换是 AD 转置,我们用它来实现反向模式 AD。对于这些,我们需要首先将程序“暂存”到 IR(内部表示),这是一个表示程序的数据结构,然后我们可以以我们喜欢的任何顺序遍历它。从 Python 程序构建此 IR 将是我们第三个也是最后一个解释器的目标。

首先,让我们定义 IR。我们将首先进行无类型 ANF IR。一个函数(我们在 JAX 中称 IR 函数为“jaxprs”)将有一个形式参数列表、一个操作列表和一个返回值。操作的每个参数都必须是一个“原子”,它可以是一个变量或一个字面量。函数的返回值也是一个原子。

Var = str           # Variables are just strings in this untyped IR
Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals

# Equation - a single line in our IR like `z = mul(x, y)`
@dataclass
class Equation:
  var  : Var         # The variable name of the result
  op   : Op          # The primitive operation we're applying
  args : tuple[Atom] # The arguments we're applying the primitive operation to

# We call an IR function a "Jaxpr", for "JAX expression"
@dataclass
class Jaxpr:
  parameters : list[Var]      # The function's formal parameters (arguments)
  equations  : list[Equation] # The body of the function, a list of instructions/equations
  return_val : Atom           # The function's return value

  def __str__(self):
    lines = []
    lines.append(', '.join(b for b in self.parameters) + ' ->')
    for eqn in self.equations:
      args_str = ', '.join(str(arg) for arg in eqn.args)
      lines.append(f'  {eqn.var} = {eqn.op}({args_str})')
    lines.append(self.return_val)
    return '\n'.join(lines)

为了从 Python 函数构建 IR,我们定义了一个 StagingInterpreter,它接受每个操作并将其添加到不断增长的我们到目前为止看到的所有操作的列表中

class StagingInterpreter(Interpreter):
  def __init__(self):
    self.equations = []         # A mutable list of all the ops we've seen so far
    self.name_counter = 0  # Counter for generating unique names

  def fresh_var(self):
    self.name_counter += 1
    return "v_" + str(self.name_counter)

  def interpret_op(self, op, args):
    binder = self.fresh_var()
    self.equations.append(Equation(binder, op, args))
    return binder

def build_jaxpr(f, num_args):
  interpreter = StagingInterpreter()
  parameters = tuple(interpreter.fresh_var() for _ in range(num_args))
  with set_interpreter(interpreter):
    result = f(*parameters)
  return Jaxpr(parameters, interpreter.equations, result)

现在我们可以为一个 Python 程序构造一个 IR 并将其打印出来

print(build_jaxpr(foo, 1))
v_1 ->
  v_2 = Op.add(v_1, 3.0)
  v_3 = Op.mul(v_1, v_2)
v_3

我们还可以通过编写一个显式解释器来评估我们的 IR,该解释器逐个遍历操作

def eval_jaxpr(jaxpr, args):
  # An environment mapping variables to values
  env = dict(zip(jaxpr.parameters, args))
  def eval_atom(x): return env[x] if isinstance(x, Var) else x
  for eqn in jaxpr.equations:
    args = tuple(eval_atom(x) for x in eqn.args)
    env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)
  return eval_atom(jaxpr.return_val)

print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
10.0

我们已经用 current_interpreter.interpret_op 编写了这个解释器,这意味着我们已经完成了一个完整的往返:可解释的 Python 程序到 IR 到可解释的 Python 程序。由于结果是“可解释的”,我们可以再次对其进行微分,或将其暂存出来,或任何我们喜欢的操作

print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))
(10.0, 7.0)

下一步…#

这就是本教程第一部分的内容。我们已经完成了两个原语、三个解释器以及将它们编织在一起的跟踪机制。在下一部分中,我们将添加浮点数以外的类型、错误处理、编译、反向模式 AD 和高阶原语。请注意,第二部分的结构不同。我们没有尝试具有一个既遵循代码依赖关系(例如,数据结构需要在使用之前定义)又遵循教学依赖关系(概念需要在实现之前引入)的从上到下的顺序,而是使用一个可以以任何顺序访问的单个文件。