
Autodidax: JAX core from scratch#

Ever want to learn how JAX works, but the implementation seemed impenetrable? Well, you're in luck! By reading this tutorial, you'll learn every big idea in JAX's core system. You'll even get clued into our weird jargon!

This is a work-in-progress draft. There are some important ingredients missing, still to come in parts 5 and 6 (and more?). There are also some simplifications here that we haven't yet applied to the main system, but we will.

Part 1: Transformations as interpreters: standard evaluation, jvp, and vmap#

We want to transform functions that look like this:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

Think of functions like sin and the arithmetic operations underlying the infix operators (mul, add, and neg) as primitives, meaning atomic units of processing rather than compositions.

"Transform" means "interpret differently." Instead of standard interpretation where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of its JVP rule, and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters.

JAX core machinery#

We can implement stacks of interpreters and even have them all discharge on the fly as we execute the Python function to be transformed. To start, let's define these primitives so that we can intercept their application:

from typing import NamedTuple

class Primitive(NamedTuple):
  name: str

add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")

def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  if axis is None:
    axis = tuple(range(np.ndim(x)))
  if type(axis) is int:
    axis = (axis,)
  return bind1(reduce_sum_p, x, axis=axis)

def bind1(prim, *args, **params):
  out, = bind(prim, *args, **params)
  return out

We'll set up array data types and infix operator methods in a moment.

A Primitive is just an object with a name, to which we attach our interpretation rules (one for each transformation). The bind function is our interception point: it'll figure out which transformation rule to apply, based on how the arguments are boxed in tracers and what interpreters are active.

The functions that user code calls, like add and sin, are just wrappers around calls to bind. These wrappers let us control how arguments are passed to bind, and in particular we follow a handy internal convention: when we call bind, we pass values representing array data as positional arguments, and we pass metadata like the axis argument to reduce_sum_p via keyword. This calling convention simplifies some core logic (since e.g. instances of the Tracer class to be defined below can only occur in positional arguments to bind). The wrappers can also provide docstrings!

We represent active interpreters as a stack. The stack is just a simple list, and each element is a container with an integer level (corresponding to the element's height in the stack), an interpreter type (which we'll call a trace_type), and an optional field for any global data the interpreter needs. We call each element a MainTrace, though maybe "Interpreter" would be more descriptive.

from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any

class MainTrace(NamedTuple):
  level: int
  trace_type: type['Trace']
  global_data: Any | None

trace_stack: list[MainTrace] = []
dynamic_trace: MainTrace | None = None  # to be employed in Part 3

@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
  level = len(trace_stack)
  main = MainTrace(level, trace_type, global_data)
  trace_stack.append(main)

  try:
    yield main
  finally:
    trace_stack.pop()

When we're about to apply a transformation, we'll push another interpreter onto the stack using new_main. Then, as we apply primitives in the function, we can think of the bind first being interpreted by the trace at the top of the stack (i.e. with the highest level). If that first interpreter itself binds other primitives in its interpretation rule for the primitive, like how the JVP rule for sin_p might bind cos_p and mul_p, then those bind calls will be handled by the interpreter at the next level down.

What goes at the bottom of the interpreter stack? At the bottom, we know all the transformation interpreters are finished, and we just want to do standard evaluation. So at the bottom we'll put an evaluation interpreter.

Let's sketch out the interface for interpreters, which is based on the Trace and Tracer base classes. A Tracer represents a boxed-up value, perhaps carrying some extra context data used by the interpreter. A Trace handles boxing up values into Tracers and also handles primitive application.

class Trace:
  main: MainTrace

  def __init__(self, main: MainTrace) -> None:
    self.main = main

  def pure(self, val): assert False  # must override
  def lift(self, val): assert False  # must override

  def process_primitive(self, primitive, tracers, params):
    assert False  # must override

The first two methods are about boxing up values in Tracers, which are the objects that flow through the Python programs we transform. The last method is the callback we'll use to interpret primitive application.

The Trace itself doesn't contain any data, other than a reference to its corresponding MainTrace instance. In fact, multiple instances of a Trace might be created and discarded during an application of a transformation, whereas only a single MainTrace instance is created per application of a transformation.

As for Tracers themselves, each one carries an abstract value (and forwards infix operators to it), and the rest is up to the transformation. (The relationship between Tracers and AbstractValues is that there's one Tracer per transformation, and at least one AbstractValue per base type, like arrays.)

import numpy as np

class Tracer:
  _trace: Trace

  __array_priority__ = 1000

  @property
  def aval(self):
    assert False  # must override

  def full_lower(self):
    return self  # default implementation

  def __neg__(self): return self.aval._neg(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __bool__(self): return self.aval._bool(self)
  def __nonzero__(self): return self.aval._nonzero(self)

  def __getattr__(self, name):
    try:
      return getattr(self.aval, name)
    except AttributeError:
      raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")

def swap(f): return lambda x, y: f(y, x)
class ShapedArray:
  array_abstraction_level = 1
  shape: tuple[int, ...]
  dtype: np.dtype

  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype

  @property
  def ndim(self):
    return len(self.shape)

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(swap(add))
  _mul = staticmethod(mul)
  _rmul = staticmethod(swap(mul))
  _gt = staticmethod(greater)
  _lt = staticmethod(less)

  @staticmethod
  def _bool(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  @staticmethod
  def _nonzero(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  def str_short(self):
    return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'

  def __hash__(self):
    return hash((self.shape, self.dtype))

  def __eq__(self, other):
    return (type(self) is type(other) and
            self.shape == other.shape and self.dtype == other.dtype)

  def __repr__(self):
    return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"

class ConcreteArray(ShapedArray):
  array_abstraction_level = 2
  val: np.ndarray

  def __init__(self, val):
    self.val = val
    self.shape = val.shape
    self.dtype = val.dtype

  @staticmethod
  def _bool(tracer):
    return bool(tracer.aval.val)

  @staticmethod
  def _nonzero(tracer):
    return bool(tracer.aval.val)

def get_aval(x):
  if isinstance(x, Tracer):
    return x.aval
  elif type(x) in jax_types:
    return ConcreteArray(np.asarray(x))
  else:
    raise TypeError(x)

jax_types = {bool, int, float,
             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}

Notice that we actually have two AbstractValues for arrays, representing different levels of abstraction. A ShapedArray represents the set of all possible arrays with a given shape and dtype. A ConcreteArray represents a singleton set consisting of a single array value.
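
For example, wrapping a concrete array gives us a ConcreteArray, which remembers the value as well as its shape and dtype (a small illustrative check, not in the original notebook; expected outputs in the comments):

example_aval = get_aval(np.array([1., 2., 3.]))
print(type(example_aval).__name__, example_aval.shape, example_aval.dtype)  # ConcreteArray (3,) float64
print(ShapedArray((3,), np.dtype('float64')))  # ShapedArray(shape=(3,), dtype=float64), the set this array belongs to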

Now that we've set up the interpreter stack, the Trace/Tracer API for interpreters, and abstract values, we can come back to implement bind:

def bind(prim, *args, **params):
  top_trace = find_top_trace(args)
  tracers = [full_raise(top_trace, arg) for arg in args]
  outs = top_trace.process_primitive(prim, tracers, params)
  return [full_lower(out) for out in outs]

The main action is that we call find_top_trace to figure out which interpreter should handle this primitive application. We then call that top trace's process_primitive so the trace can apply its interpretation rule. The calls to full_raise just ensure that the inputs are boxed in the top trace's Tracer instances, and the call to full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.

import operator as op

def find_top_trace(xs) -> Trace:
  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
                 default=trace_stack[0], key=op.attrgetter('level'))
  if dynamic_trace and dynamic_trace.level > top_main.level:
    top_main = dynamic_trace
  return top_main.trace_type(top_main)

In words, ignoring the dynamic_trace step until Part 3, find_top_trace returns the highest-level interpreter associated with the Tracers on its inputs, and otherwise returns the interpreter at the bottom of the stack (which is always an evaluation trace, at least for now). This is a deviation from the description above, in which we always start by running the interpreter at the top of the stack and then work our way down, applying every interpreter in the stack. Instead, we only apply an interpreter when the input arguments to a primitive bind are boxed in a Tracer corresponding to that interpreter. This optimization lets us skip irrelevant transformations, but bakes in an assumption that transformations mostly follow data dependence (except for the special bottom-of-the-stack interpreter, which interprets everything).

An alternative would be to have every interpreter in the stack interpret every operation. That's worth exploring! JAX is designed around data dependence in large part because that's so natural for automatic differentiation, and JAX's roots are in autodiff. But it may be over-fit.

def full_lower(val: Any):
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val

def full_raise(trace: Trace, val: Any) -> Tracer:
  if not isinstance(val, Tracer):
    assert type(val) in jax_types
    return trace.pure(val)
  level = trace.main.level
  if val._trace.main is trace.main:
    return val
  elif val._trace.main.level < level:
    return trace.lift(val)
  elif val._trace.main.level > level:
    raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
  else:  # val._trace.level == level
    raise Exception(f"Different traces at same level: {val._trace}, {trace}.")

The logic in full_raise serves to box values into Tracers for a particular Trace, calling different methods on the Trace based on context: Trace.pure is called on non-Tracer constants, and Trace.lift is called on values that are already Tracers from a lower-level interpreter. These two methods could share the same implementation, but by distinguishing them in the core logic we can provide more information to the Trace subclasses.

That's it for the JAX core! Now we can start adding interpreters.

Evaluation interpreter#

We'll start with the simplest interpreter: the evaluation interpreter that will sit at the bottom of the interpreter stack.

class EvalTrace(Trace):
  pure = lift = lambda self, x: x  # no boxing in Tracers needed

  def process_primitive(self, primitive, tracers, params):
    return impl_rules[primitive](*tracers, **params)

trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack

# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}

impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]

def broadcast_impl(x, *, shape, axes):
  for axis in sorted(axes):
    x = np.expand_dims(x, axis)
  return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl

With this interpreter, we can evaluate user functions:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(f(3.0))
2.7177599838802657

Woo! Like going around in a big circle. But the point of this indirection is that now we can add some real transformations.

Forward-mode autodiff with jvp#

First, a few helper functions:

import builtins

def zeros_like(val):
  aval = get_aval(val)
  return np.zeros(aval.shape, aval.dtype)

def unzip2(pairs):
  lst1, lst2 = [], []
  for x1, x2 in pairs:
    lst1.append(x1)
    lst2.append(x2)
  return lst1, lst2

def map(f, *xs):
  return list(builtins.map(f, *xs))

def zip(*args):
  fst, *rest = args = map(list, args)
  n = len(fst)
  for arg in rest:
    assert len(arg) == n
  return list(builtins.zip(*args))
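
As a quick, purely illustrative sanity check of these helpers (expected outputs in the comments):

print(unzip2([(1, 'a'), (2, 'b')]))  # ([1, 2], ['a', 'b'])
print(zip([1, 2], [3, 4]))           # [(1, 3), (2, 4)] -- our zip returns a list and checks lengths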

The Tracer for forward-mode autodiff carries a primal-tangent pair. The Trace applies JVP rules.

class JVPTracer(Tracer):
  def __init__(self, trace, primal, tangent):
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    return get_aval(self.primal)

class JVPTrace(Trace):
  pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))

  def process_primitive(self, primitive, tracers, params):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    jvp_rule = jvp_rules[primitive]
    primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
    return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]

jvp_rules = {}

Notice both pure and lift package a value into a JVPTracer with the minimal amount of context, which is a zero tangent value.

Let's add some JVP rules for primitives:

def add_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp

def mul_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp

def sin_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp

def cos_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp

def neg_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp

def reduce_sum_jvp(primals, tangents, *, axis):
  (x,), (x_dot,) = primals, tangents
  return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp

def greater_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = greater(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp

def less_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = less(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp

Finally, we add a transformation API to kick off the trace:

def jvp_v1(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    out = f(*tracers_in)
    tracer_out = full_raise(trace, out)
    primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
  return primal_out, tangent_out

And with that, we can differentiate!

x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
-0.9899924966004454
-0.9899924966004454
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def deriv(f):
  return lambda x: jvp_v1(f, (x,), (1.,))[1]

print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672
def f(x):
  if x > 0.:  # Python control flow
    return 2. * x
  else:
    return x

print(deriv(f)(3.))
print(deriv(f)(-3.))
2.0
1.0
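
As noted above, constants that enter the trace are boxed by pure with a zero tangent, so they contribute nothing to the derivative. A small illustrative check (expected output in the comment):

y, ydot = jvp_v1(lambda x: x * 3., (2.,), (1.,))
print(y, ydot)  # 6.0 3.0 -- the constant 3. was lifted with a zero tangent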

Pytrees and flattening user functions' inputs and outputs#

One limitation of jvp_v1 is that it assumes the user function accepts arrays as positional arguments and produces a single array as output. What if it produced a list as output? Or accepted nested containers as inputs? It would be a pain to deal with all the possible containers in inputs and outputs at every layer of the stack. Instead, we can wrap the user function so that the wrapped version accepts arrays as inputs and returns a flat list of arrays as output. The wrapper just needs to unflatten its input, call the user function, and flatten the output.

Assuming the user always gives us functions that take arrays as inputs and produce a flat list of arrays as output, here's how we'd write jvp:

def jvp_flat(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
  return primals_out, tangents_out

To support user functions with arbitrary containers in the inputs and outputs, here's how we'd write the user-facing jvp wrapper:

def jvp(f, primals, tangents):
  primals_flat, in_tree = tree_flatten(primals)
  tangents_flat, in_tree2 = tree_flatten(tangents)
  if in_tree != in_tree2: raise TypeError
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)
  tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
  return primals_out, tangents_out

Notice that we had to plumb the tree structure of the user function's output back to the caller of flatten_fun. That information isn't available until we actually run the user function, so flatten_fun just returns a reference to a mutable cell, represented as a thunk. These side effects are safe because we always run the user function exactly once. (This safe regime is the reason for the "linear" in the name linear_util.py, in the sense of linear types.)

All that remains is to write tree_flatten, tree_unflatten, and flatten_fun.


def flatten_fun(f, in_tree):
  store = Store()

  def flat_fun(*args_flat):
    pytree_args = tree_unflatten(in_tree, args_flat)
    out = f(*pytree_args)
    out_flat, out_tree = tree_flatten(out)
    store.set_value(out_tree)
    return out_flat

  return flat_fun, store

class Empty: pass
empty = Empty()

class Store:
  val = empty

  def set_value(self, val):
    assert self.val is empty
    self.val = val

  def __call__(self):
    return self.val


from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from collections.abc import Callable

class NodeType(NamedTuple):
  name: str
  to_iterable: Callable
  from_iterable: Callable

def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable
                         ) -> None:
  node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types: dict[type, NodeType] = {}
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list,  lambda l: (None, l), lambda _, xs:  list(xs))
register_pytree_node(dict,
                     lambda d: map(tuple, unzip2(sorted(d.items()))),
                     lambda keys, vals: dict(zip(keys, vals)))

class PyTreeDef(NamedTuple):
  node_type: NodeType
  node_metadata: Hashable
  child_treedefs: tuple['PyTreeDef', ...]

class Leaf: pass
leaf = Leaf()

def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
  children_iter, treedef = _tree_flatten(x)
  return list(children_iter), treedef

def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
  node_type = node_types.get(type(x))
  if node_type:
    node_metadata, children = node_type.to_iterable(x)
    children_flat, child_trees = unzip2(map(_tree_flatten, children))
    flattened = it.chain.from_iterable(children_flat)
    return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
  else:
    return [x], leaf

def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
  return _tree_unflatten(treedef, iter(xs))

def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
  if treedef is leaf:
    return next(xs)
  else:
    children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
    return treedef.node_type.from_iterable(treedef.node_metadata, children)
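
Here's a quick illustration of the flatten/unflatten round trip on a nested container (a sketch; expected outputs in the comments):

leaves, treedef = tree_flatten({'a': [1, 2], 'b': (3,)})
print(leaves)                           # [1, 2, 3]
print(tree_unflatten(treedef, leaves))  # {'a': [1, 2], 'b': (3,)}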

With this pytree-handling jvp implementation, we can now handle arbitrary input and output containers. That'll come in handy with future transformations too!

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return {'hi': z, 'there': [x, y]}

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
{'hi': np.float64(2.7177599838802657), 'there': [3.0, np.float64(0.2822400161197344)]}
{'hi': np.float64(2.979984993200891), 'there': [1.0, np.float64(-1.9799849932008908)]}

Vectorized batching with vmap#

First, a couple of helper functions, one for producing mapped abstract values from unmapped ones (by removing an axis), and one for moving batch dimensions around:

def mapped_aval(batch_dim, aval):
  shape = list(aval.shape)
  del shape[batch_dim]
  return ShapedArray(tuple(shape), aval.dtype)

def move_batch_axis(axis_size, src, dst, x):
  if src is not_mapped:
    target_shape = list(np.shape(x))
    target_shape.insert(dst, axis_size)
    return broadcast(x, target_shape, [dst])
  elif src == dst:
    return x
  else:
    return moveaxis(x, src, dst)

def moveaxis(x, src: int, dst: int):
  perm = [i for i in range(np.ndim(x)) if i != src]
  perm.insert(dst, src)
  return transpose(x, perm)

The Tracer for vectorized batching carries a batched value and an optional integer indicating which axis (if any) is the batch axis.

from typing import Union

class NotMapped: pass
not_mapped = NotMapped()

BatchAxis = Union[NotMapped, int]

class BatchTracer(Tracer):
  def __init__(self, trace, val, batch_dim: BatchAxis):
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    if self.batch_dim is not_mapped:
      return get_aval(self.val)
    else:
      return mapped_aval(self.batch_dim, get_aval(self.val))

  def full_lower(self):
    if self.batch_dim is not_mapped:
      return full_lower(self.val)
    else:
      return self

class BatchTrace(Trace):
  pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)

  def process_primitive(self, primitive, tracers, params):
    vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    vmap_rule = vmap_rules[primitive]
    val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
    return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]

  @property
  def axis_size(self):
    return self.main.global_data

vmap_rules = {}

Here we've implemented the optional Tracer.full_lower method, which lets us peel off a batching tracer when it isn't needed because it doesn't represent a batched value.

For BatchTrace, analogously to JVPTrace, the methods pure and lift just box a value in a BatchTracer with the minimal amount of context, which in this case is a batch_dim taking the sentinel value not_mapped. Notice we use the MainTrace's interpreter-global data field to store the batch axis size.

Next we can define batching interpreter rules for each primitive:

from functools import partial

def binop_batching_rule(op, axis_size, vals_in, dims_in):
  (x, y), (x_bdim, y_bdim) = vals_in, dims_in
  if x_bdim != y_bdim:
    if x_bdim is not_mapped:
      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
      x_bdim = y_bdim
    else:
      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
  return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)

def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
  (x,), (x_bdim,) = vals_in, dims_in
  return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)

def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
  (x,), (x_bdim,) = vals_in, dims_in
  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
  return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
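
To make the axis arithmetic concrete, here's a small worked example calling the rule directly (an illustrative sketch): a value of logical shape (3, 4) batched along axis 0 with size 5 has physical shape (5, 3, 4), so summing over logical axis 1 means summing physical axis 2, and the batch dimension stays at 0.

xs = np.ones((5, 3, 4))  # batch axis 0, logical shape (3, 4)
outs, (out_bdim,) = reduce_sum_batching_rule(5, (xs,), (0,), axis=(1,))
print(outs[0].shape, out_bdim)  # (5, 3) 0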

Finally, we add a transformation API to kick off the trace:

def vmap_flat(f, in_axes, *args):
  axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
                if ax is not not_mapped}
  with new_main(BatchTrace, axis_size) as main:
    trace = BatchTrace(main)
    tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x
                  for x, ax in zip(args, in_axes)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
  outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
                     for val_out, bdim in zip(vals_out, bdims_out)]
  return outs_transposed

def vmap(f, in_axes):
  def batched_f(*args):
    args_flat, in_tree = tree_flatten(args)
    in_axes_flat, in_tree2 = tree_flatten(in_axes)
    if in_tree != in_tree2: raise TypeError
    f_flat, out_tree = flatten_fun(f, in_tree)
    outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
    return tree_unflatten(out_tree(), outs_flat)
  return batched_f
def add_one_to_a_scalar(scalar):
  assert np.ndim(scalar) == 0
  return 1 + scalar

vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)

print(vector_in)
print(vector_out)
[0. 1. 2.]
[1. 2. 3.]
def jacfwd(f, x):
  pushfwd = lambda v: jvp(f, (x,), (v,))[1]
  vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
  return vmap(pushfwd, (0,))(vecs_in)

def f(x):
  return sin(x)

jacfwd(f, np.arange(3.))
array([[ 1.        ,  0.        , -0.        ],
       [ 0.        ,  0.54030231, -0.        ],
       [ 0.        ,  0.        , -0.41614684]])

That's it for jvp and vmap!

Part 2: Jaxprs#

The next transformations on the horizon are jit for just-in-time compilation and vjp for reverse-mode autodiff. (grad is just a small wrapper around vjp.) Whereas jvp and vmap only needed each Tracer to carry a little bit of extra context, for jit and vjp we need much richer context: we need to represent programs. That is, we need jaxprs!

Jaxprs are JAX's internal intermediate representation of programs. They are explicitly typed, functional, first-order, and in ANF form. We need a program representation for jit because the purpose of jit is to stage computation out of Python. For any computation we want to stage out, we need to be able to represent it as data, and build it up as we trace a Python function. Similarly, vjp needs a way to represent the computation for the backward pass of reverse-mode autodiff. We use the same jaxpr program representation for both needs.

(Building a program representation is the most free kind of trace-transformation, and so except for issues around handling native Python control flow, any transformation could be implemented by first tracing to a jaxpr and then interpreting the jaxpr.)

Jaxpr data structures#

The syntax of a jaxpr term is roughly:

jaxpr ::=
  { lambda <binder> , ... .
    let <eqn>
        ...
    in ( <atom> , ... ) }

binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>

eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...

The syntax of types is:

jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ...

How do we represent these as Python data structures? We reuse ShapedArrays to represent types, and we can represent the term syntax with a few Python structs:

class Var:
  aval: ShapedArray
  def __init__(self, aval): self.aval = aval

class Lit:
  val: Any
  aval: ShapedArray

  def __init__(self, val):
    self.aval = aval = raise_to_shaped(get_aval(val))
    self.val = np.array(val, aval.dtype)

Atom = Union[Var, Lit]

class JaxprEqn(NamedTuple):
  primitive: Primitive
  inputs: list[Atom]
  params: dict[str, Any]
  out_binders: list[Var]

class Jaxpr(NamedTuple):
  in_binders: list[Var]
  eqns: list[JaxprEqn]
  outs: list[Atom]

  def __hash__(self): return id(self)
  __eq__ = op.is_

def raise_to_shaped(aval):
  return ShapedArray(aval.shape, aval.dtype)

Type-checking a jaxpr involves checking that there are no unbound variables, that variables are only bound once, and that for each equation the type of the primitive application matches the type of the output binders.

class JaxprType(NamedTuple):
  in_types:  list[ShapedArray]
  out_types: list[ShapedArray]

  def __repr__(self):
    in_types = ', '.join(aval.str_short() for aval in self.in_types)
    out_types = ', '.join(aval.str_short() for aval in self.out_types)
    return f'({in_types}) -> ({out_types})'

def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
  env: set[Var] = set()

  for v in jaxpr.in_binders:
    if v in env: raise TypeError
    env.add(v)

  for eqn in jaxpr.eqns:
    in_types = [typecheck_atom(env, x) for x in eqn.inputs]
    out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
    for out_binder, out_type in zip(eqn.out_binders, out_types):
      if not out_type == out_binder.aval: raise TypeError
    for out_binder in eqn.out_binders:
      if out_binder in env: raise TypeError
      env.add(out_binder)

  in_types = [v.aval for v in jaxpr.in_binders]
  out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
  return JaxprType(in_types, out_types)

def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
  if isinstance(x, Var):
    if x not in env: raise TypeError("unbound variable")
    return x.aval
  elif isinstance(x, Lit):
    return raise_to_shaped(get_aval(x.val))
  else:
    assert False

We can apply the function represented by a jaxpr to arguments with a simple interpreter.

def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
  env: dict[Var, Any] = {}

  def read(x: Atom) -> Any:
    return env[x] if type(x) is Var else x.val

  def write(v: Var, val: Any) -> None:
    assert v not in env  # single-assignment
    env[v] = val

  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.inputs)
    outs = bind(eqn.primitive, *in_vals, **eqn.params)
    map(write, eqn.out_binders, outs)
  return map(read, jaxpr.outs)

def jaxpr_as_fun(jaxpr: Jaxpr):
  return lambda *args: eval_jaxpr(jaxpr, args)

By using bind in the interpreter, this interpreter is itself traceable.
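
To make that concrete, here's a small hand-built jaxpr for x -> sin(x) * 2. (an illustrative sketch; the jaxpr-building machinery below will construct these for us). Because eval_jaxpr goes through bind, we can apply jvp to jaxpr_as_fun of it directly (expected outputs in the comments):

a = Var(ShapedArray((), np.dtype('float64')))
b, c = Var(a.aval), Var(a.aval)
hand_jaxpr = Jaxpr(in_binders=[a],
                   eqns=[JaxprEqn(sin_p, [a], {}, [b]),
                         JaxprEqn(mul_p, [b, Lit(2.)], {}, [c])],
                   outs=[c])
print(eval_jaxpr(hand_jaxpr, [3.])[0])  # 0.2822400161197344, i.e. sin(3.) * 2.
(primal_out,), (tangent_out,) = jvp(jaxpr_as_fun(hand_jaxpr), (3.,), (1.,))
print(primal_out, tangent_out)          # 0.2822400161197344 -1.9799849932008908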

Building jaxprs with tracing#

Now that we have jaxprs as a data structure, we need ways to produce these from tracing Python code. In general there are two variants of how we trace to a jaxpr; jit uses one and vjp uses the other. We'll start with the one used by jit, which is also used by control flow primitives like lax.cond, lax.while_loop, and lax.scan.

def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
  assert 0 <= n <= len(lst)
  return lst[:n], lst[n:]

def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
  assert len(bs) == len(l)
  lists = lst1, lst2 = [], []
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lst1, lst2
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
  __slots__ = ['aval']
  aval: ShapedArray

  def __init__(self, trace, aval):
    self._trace = trace
    self.aval = aval

# NB: the analogous class in JAX is called 'DynamicJaxprTrace'
class JaxprTrace(Trace):
  def new_arg(self, aval: ShapedArray) -> JaxprTracer:
    aval = raise_to_shaped(aval)
    tracer = self.builder.new_tracer(self, aval)
    self.builder.tracer_to_var[id(tracer)] = Var(aval)
    return tracer

  def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
    tracer = self.builder.const_tracers.get(id(val))
    if tracer is None:
      tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
      self.builder.add_const(tracer, val)
    return tracer
  pure = lift = get_or_make_const_tracer

  def process_primitive(self, primitive, tracers, params):
    avals_in = [t.aval for t in tracers]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
    inputs = [self.builder.getvar(t) for t in tracers]
    outvars = [self.builder.add_var(t) for t in out_tracers]
    self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
    return out_tracers

  @property
  def builder(self):
    return self.main.global_data

# NB: in JAX, we instead attach abstract eval rules to Primitive instances
abstract_eval_rules = {}

Notice that we keep as interpreter-global data a builder object, which keeps track of variables, constants, and eqns as we build up the jaxpr.

class JaxprBuilder:
  eqns: list[JaxprEqn]
  tracer_to_var: dict[int, Var]
  const_tracers: dict[int, JaxprTracer]
  constvals: dict[Var, Any]
  tracers: list[JaxprTracer]

  def __init__(self):
    self.eqns = []
    self.tracer_to_var = {}
    self.const_tracers = {}
    self.constvals = {}
    self.tracers = []

  def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
    tracer = JaxprTracer(trace, aval)
    self.tracers.append(tracer)
    return tracer

  def add_eqn(self, eqn: JaxprEqn) -> None:
    self.eqns.append(eqn)

  def add_var(self, tracer: JaxprTracer) -> Var:
    assert id(tracer) not in self.tracer_to_var
    var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
    return var

  def getvar(self, tracer: JaxprTracer) -> Var:
    var = self.tracer_to_var.get(id(tracer))
    assert var is not None
    return var

  def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
    var = self.add_var(tracer)
    self.const_tracers[id(val)] = tracer
    self.constvals[var] = val
    return var

  def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
            ) -> tuple[Jaxpr, list[Any]]:
    constvars, constvals = unzip2(self.constvals.items())
    t2v = lambda t: self.tracer_to_var[id(t)]
    in_binders = constvars + [t2v(t) for t in in_tracers]
    out_vars = [t2v(t) for t in out_tracers]
    jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
    typecheck_jaxpr(jaxpr)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    return jaxpr, constvals
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
  new_const_binders, lit_binders = partition_list(scalars, const_binders)
  new_consts, lit_vals = partition_list(scalars, consts)
  literals = dict(zip(lit_binders, map(Lit, lit_vals)))
  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
                       eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
  new_outs = [literals.get(x, x) for x in jaxpr.outs]
  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
  typecheck_jaxpr(new_jaxpr)
  return new_jaxpr, new_consts

The rules we need for JaxprTrace.process_primitive are essentially typing rules for primitive applications: given the primitive, its parameters, and types for the inputs, the rule must produce a type for the output, which is then packaged with the output JaxprTracer. We can use abstract evaluation rules for this same purpose, even though they can be more general (since abstract evaluation rules must accept ConcreteArray inputs, and since they need only return an upper bound on the set of possible outputs, they can produce ConcreteArray outputs as well). We'll reuse these abstract evaluation rules for the other jaxpr-producing trace machinery, where the potential extra generality is useful.

def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval

def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if x.shape != y.shape: raise TypeError
  return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval

def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval

def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
                             ) -> list[ShapedArray]:
  axis_ = set(axis)
  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
  return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval

def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
                            axes: Sequence[int]) -> list[ShapedArray]:
  return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
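
These rules are plain functions on types, so we can call them directly; note they also accept ConcreteArray inputs, since ConcreteArray is a subclass of ShapedArray (an illustrative check, expected output in the comment):

print(binop_abstract_eval(ConcreteArray(np.ones(3)),
                          ShapedArray((3,), np.dtype('float64'))))
# [ShapedArray(shape=(3,), dtype=float64)]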

To check our implementation of jaxprs, we can add a make_jaxpr transformation and a pretty-printer:

from functools import lru_cache

@lru_cache  # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    trace = JaxprTrace(main)
    tracers_in = [trace.new_arg(aval) for aval in avals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()


from collections import defaultdict
import string

class PPrint:
  lines: list[tuple[int, str]]

  def __init__(self, lines):
    self.lines = lines

  def indent(self, indent: int) -> 'PPrint':
    return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])

  def __add__(self, rhs: 'PPrint') -> 'PPrint':
    return PPrint(self.lines + rhs.lines)

  def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
    if not rhs.lines: return self
    if not self.lines: return rhs
    indent, s = self.lines[-1]
    indented_block = rhs.indent(indent + len(s))
    common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
    return PPrint(self.lines[:-1]
                  + [(indent, common_line)]
                  + indented_block.lines[1:])

  def __str__(self) -> str:
    return '\n'.join(' ' * indent + s for indent, s in self.lines)

def pp(s: Any) -> PPrint:
  return PPrint([(0, line) for line in str(s).splitlines()])

def vcat(ps: list[PPrint]) -> PPrint:
  return sum(ps, pp(''))

def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
  namegen = (''.join(s) for r in it.count(1)
             for s in it.permutations(string.ascii_lowercase, r))
  names = defaultdict(lambda: next(namegen))
  in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
  eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
  outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
                   for v in jaxpr.outs)
  return (pp(f'{{ lambda {in_binders} .') +
          ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))

def var_str(names: defaultdict[Var, str], v: Var) -> str:
  return f'{names[v]}:{v.aval.str_short()}'

def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  rule = pp_rules.get(eqn.primitive)
  if rule:
    return rule(names, eqn)
  else:
    lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
    rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
           pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                       for x in eqn.inputs)))
    return lhs >> pp(' = ') >> rhs

def pp_params(params: dict[str, Any]) -> PPrint:
  items = sorted(params.items())
  if items:
    return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')
  else:
    return pp(' ')

Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: dict[Primitive, Callable[..., PPrint]] = {}
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
{ lambda a:float64[] .
  let b:float64[] = mul 2.0 a
  in ( b ) }
(float64[]) -> (float64[])

But there's a limitation here: because of how find_top_trace operates by data dependence, make_jaxpr_v1 can't stage out all the primitive operations performed by the Python callable it's given. For example:

jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
{ lambda  .
  let 
  in ( 4.0 ) }

This is precisely the issue that omnistaging fixed. We want to ensure that the JaxprTrace started by make_jaxpr is always applied, regardless of whether any inputs to bind are boxed in corresponding JaxprTracer instances. We can achieve this by employing the dynamic_trace global defined in Part 1:

@contextmanager
def new_dynamic(main: MainTrace):
  global dynamic_trace
  prev_dynamic_trace, dynamic_trace = dynamic_trace, main
  try:
    yield
  finally:
    dynamic_trace = prev_dynamic_trace

@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
               ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    with new_dynamic(main):
      trace = JaxprTrace(main)
      tracers_in = [trace.new_arg(aval) for aval in avals_in]
      outs = f(*tracers_in)
      tracers_out = [full_raise(trace, out) for out in outs]
      jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()

jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
{ lambda  .
  let a:float64[] = mul 2.0 2.0
  in ( a ) }

Using dynamic_trace this way is conceptually the same as stashing the current interpreter stack and starting a new one with the JaxprTrace at the bottom. That is, no interpreters lower in the stack than the dynamic_trace are applied (since JaxprTrace.process_primitive doesn't call bind), though if the Python callable being traced to a jaxpr itself uses transformations then those can be pushed onto the interpreter stack above the JaxprTrace. But temporarily stashing the interpreter stack would break up the system state. The dynamic_trace tag achieves the same goals while keeping the system state simpler.

That's it for jaxprs! With jaxprs in hand, we can implement the remaining major JAX features.

Part 3: jit, simplified#

While jit has a transformation-like API in that it accepts a Python callable as an argument, under the hood it's really a higher-order primitive rather than a transformation. A primitive is higher-order when it's parameterized by a function.

On-the-fly ("final style") and staged ("initial style") processing#

There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs:

  1. On-the-fly processing, where bind takes a Python callable as an argument. We defer forming a jaxpr until as late as possible, namely until we're running the final interpreter at the bottom of the interpreter stack. That way we can swap a JaxprTrace in at the bottom of the interpreter stack and thus stage out rather than execute all primitive operations. With this approach, transformations in the stack get applied as usual as we execute the Python callable. This approach can be very tricky to implement, but it's as general as possible because it allows higher-order primitives not to raise the abstraction level of their arguments and thus allows data-dependent Python control flow. We refer to this approach as using a "final-style higher-order primitive" employing the discharge-at-trace-time "final-style transformations" we've used so far.

  2. Staged processing, where bind takes a jaxpr as an argument. Before we call bind, in the primitive wrapper we can just use make_jaxpr to form a jaxpr up-front and be done with the Python callable entirely. In this case, make_jaxpr puts its JaxprTrace at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive's rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can't support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an "initial-style higher-order primitive", and say that its jaxpr-processing transformation rules are "initial-style transformation rules."

The latter approach fits for jit because we don't need to support data-dependent Python control flow in the user-provided Python callable: the whole point of jit is to stage computation out of Python to be executed by XLA. (In contrast, custom_jvp is a higher-order primitive in which we do want to support data-dependent Python control flow.)

Historically, we started using the terms "initial style" and "final style" after reading the typed tagless final interpreters paper, and jokingly referring to JAX as an implementation of "untyped tagful final interpreters." We don't claim to carry over (or understand) any deep meaning behind these terms; we loosely use "initial style" to mean "build an AST and then transform it" and "final style" to mean "transform as we trace." But it's just imprecise yet sticky jargon.

With the initial-style approach, here's the user-facing jit wrapper:

def jit(f):
  def f_jitted(*args):
    avals_in = [raise_to_shaped(get_aval(x)) for x in args]
    jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
    outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
    return tree_unflatten(out_tree, outs)
  return f_jitted

xla_call_p = Primitive('xla_call')

With any new primitive, we need to give it transformation rules, starting with its evaluation rule. When we evaluate an application of the xla_call primitive, we want to stage out the computation to XLA. That involves translating the jaxpr to an XLA HLO program, transferring the argument values to the XLA device, executing the XLA program, and transferring back the results. We'll cache the XLA HLO compilation so that for each jit-ted function it only needs to be performed once per argument shape and dtype signature.

First, some utilities.

class IDHashable:
  val: Any

  def __init__(self, val):
    self.val = val

  def __hash__(self) -> int:
    return id(self.val)

  def __eq__(self, other):
    return type(other) is IDHashable and id(self.val) == id(other.val)
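
The point of IDHashable is that the compilation cache below keys on object identity rather than on array value (an illustrative check; expected outputs in the comments):

a1, a2 = np.ones(3), np.ones(3)
print(IDHashable(a1) == IDHashable(a1))  # True  -- same object
print(IDHashable(a1) == IDHashable(a2))  # False -- equal values, different objects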

Next, we'll define the evaluation rule for xla_call:

import io
from jax.extend.mlir import ir
from jax.extend.mlir.dialects import func
from jax.extend.mlir.dialects import stablehlo as hlo
from jax._src import xla_bridge as xb

class MlirContext(NamedTuple):
  module: ir.Module
  symbol_table: ir.SymbolTable

def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
  consts, args = args[:num_consts], args[num_consts:]
  hashable_consts = tuple(map(IDHashable, consts))
  execute = xla_callable(IDHashable(jaxpr), hashable_consts)
  return execute(*args)
impl_rules[xla_call_p] = xla_call_impl

@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
                 hashable_consts: tuple[IDHashable, ...]):
  jaxpr: Jaxpr = hashable_jaxpr.val
  typecheck_jaxpr(jaxpr)
  consts = [x.val for x in hashable_consts]
  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]

  with ir.Context() as ctx, ir.Location.unknown(ctx):
    hlo.register_dialect(ctx)
    m = ir.Module.create()
    c = MlirContext(m, ir.SymbolTable(m.operation))

    with ir.InsertionPoint(c.module.body):
      @func.func(*(aval_to_ir_type(aval) for aval in in_avals))
      def main(*params):
        return jaxpr_subcomp(c, jaxpr, _hlo_consts(consts) + params)

  output = io.StringIO()
  c.module.operation.print(file=output)
  backend = xb.get_backend(None)
  compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1])
  return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])

def _mlir_dtype(dtype: np.dtype) -> ir.Type:
  if np.issubdtype(dtype, np.signedinteger):
    return ir.IntegerType.get_signless(np.iinfo(dtype).bits)
  elif dtype == np.float32:
    return ir.F32Type.get()
  elif dtype == np.float64:
    return ir.F64Type.get()
  else:
    raise NotImplementedError("MLIR conversion not implemented for ", dtype)

def aval_to_ir_type(aval: ShapedArray) -> ir.Type:
  return ir.RankedTensorType.get(aval.shape, _mlir_dtype(aval.dtype))

def _hlo_const(x: Any) -> ir.Value:
  a = np.asarray(x)
  if a.dtype == np.bool_:
    return hlo.constant(ir.DenseElementsAttr.get(
      np.packbits(a, bitorder='little'), type=ir.IntegerType.get_signless(1),
      shape=a.shape))
  else:
    return hlo.constant(ir.DenseElementsAttr.get(a))

def _hlo_consts(consts: list[Any]) -> list[ir.Value]:
  unique_consts = {id(cnst): cnst for cnst in consts}
  ir_consts = {id_: _hlo_const(cnst) for id_, cnst in unique_consts.items()}
  return tuple(ir_consts[id(cnst)] for cnst in consts)

The main action is in xla_callable, which compiles a jaxpr into an XLA HLO program using jaxpr_subcomp, then returns a callable which executes the compiled program:

def jaxpr_subcomp(c: MlirContext, jaxpr: Jaxpr, args: list[ir.Value]) -> list[ir.Value]:
  env: dict[Var, ir.Value] = {}

  def read(x: Atom) -> ir.Value:
    return env[x] if type(x) is Var else _hlo_const(np.asarray(x.val))

  def write(v: Var, val: ir.Value) -> None:
    env[v] = val

  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_avals = [x.aval for x in eqn.inputs]
    in_vals = map(read, eqn.inputs)
    out_avals = [x.aval for x in eqn.out_binders]
    rule = hlo_translations[eqn.primitive]
    assert all(isinstance(v, ir.Value) for v in in_vals), in_vals
    out_vals = rule(c, in_avals, out_avals, in_vals, **eqn.params)
    assert all(isinstance(v, ir.Value) for v in out_vals), out_vals
    map(write, eqn.out_binders, out_vals)
  return map(read, jaxpr.outs)

def execute_compiled(compiled, out_avals, *args):
  input_bufs = [input_handlers[type(x)](x) for x in args]
  out_bufs = compiled.execute(input_bufs)
  return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]

default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
                  [bool, int, float, np.ndarray, np.float64, np.float32]}

def handle_result(aval: ShapedArray, buf):
  del aval  # Unused for now
  return np.asarray(buf)

hlo_translations = {}

Notice that jaxpr_subcomp has the structure of a simple interpreter. That's a common pattern: the way we process jaxprs is usually with an interpreter. And as with any interpreter, we need an interpretation rule for each primitive:

def direct_translation(op, c, in_avals, out_avals, in_vals):
  del c, in_avals, out_avals
  return [op(*in_vals)]

hlo_translations[add_p] = partial(direct_translation, hlo.add)
hlo_translations[mul_p] = partial(direct_translation, hlo.multiply)
hlo_translations[neg_p] = partial(direct_translation, hlo.negate)
hlo_translations[sin_p] = partial(direct_translation, hlo.sine)
hlo_translations[cos_p] = partial(direct_translation, hlo.cosine)

def compare_translation(op, c, in_avals, out_avals, in_vals):
  del c, out_avals
  return [hlo.compare(*in_vals, hlo.ComparisonDirectionAttr.get(op))]

hlo_translations[greater_p] = partial(compare_translation, "GT")
hlo_translations[less_p] = partial(compare_translation, "LT")

def reduce_sum_translation(c, in_avals, out_avals, in_vals, *, axis):
  del c
  (x_aval,), (out_aval,), (x,) = in_avals, out_avals, in_vals
  op = hlo.ReduceOp(
    [aval_to_ir_type(out_aval)], [x], [_hlo_const(np.array(0, x_aval.dtype))],
    axis)
  scalar_type = aval_to_ir_type(ShapedArray((), x_aval.dtype))
  reducer_region = op.body.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer_region):
    hlo.return_([hlo.add(*reducer_region.arguments)])
  return op.results

hlo_translations[reduce_sum_p] = reduce_sum_translation

def broadcast_translation(c, in_avals, out_avals, in_vals, *, shape, axes):
  del c
  (x,), (out_aval,) = in_vals, out_avals
  dims_complement = [i for i in range(len(shape)) if i not in axes]
  return [hlo.broadcast_in_dim(aval_to_ir_type(out_aval), x, dims_complement)]
hlo_translations[broadcast_p] = broadcast_translation

With that, we can now use jit to stage out, compile, and execute programs with XLA!

@jit
def f(x, y):
  print('tracing!')
  return sin(x) * cos(y)
z = f(3., 4.)  # 'tracing!' prints the first time
print(z)
tracing!
-0.09224219304455371
z = f(4., 5.)  # 'tracing!' doesn't print, compilation cache hit!
print(z)
-0.21467624978306993
@jit
def f(x):
  return reduce_sum(x, axis=0)

print(f(np.array([1., 2., 3.])))
6.0
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

def deriv(f):
  return lambda x: jvp(f, (x,), (1.,))[1]

print(    deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
0.2822400161197344
0.2822400161197344

Instead of implementing jit to first trace to a jaxpr and then lower the jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step and just lowered to HLO while tracing. That is, perhaps we could have instead implemented jit with a Trace and Tracer that appended to the XLA HLO graph incrementally on each primitive bind. That's correct for now, but won't be possible when we introduce compiled SPMD computations, because there we must know the number of replicas needed before compiling the program.

We haven't yet defined any transformation rules for xla_call_p other than its evaluation rule. That is, we can't yet do vmap-of-jit or jvp-of-jit or even jit-of-jit. Instead jit has to be at the "top level." Let's fix that!

def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  n = len(outs) // 2
  primals_out, tangents_out = outs[:n], outs[n:]
  return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule

@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
  def jvp_traceable(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)

  in_avals = [v.aval for v in jaxpr.in_binders]
  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
  return new_jaxpr, new_consts
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule

@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
  in_avals = [unmapped_aval(axis_size, d, v.aval)
              for v, d in zip(jaxpr.in_binders, bdims_in)]
  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
  return new_jaxpr, new_consts

def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
  if batch_dim is not_mapped:
    return aval
  else:
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)
    return ShapedArray(tuple(shape), aval.dtype)
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
  del num_consts  # Unused
  jaxpr_type = typecheck_jaxpr(jaxpr)
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule

def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts):
  del num_consts, out_avals
  # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
  with ir.InsertionPoint(c.module.body):
    @func.func(*(aval_to_ir_type(aval) for aval in in_avals))
    def inner_xla_call(*params):
      return jaxpr_subcomp(c, jaxpr, params)
    name = c.symbol_table.insert(inner_xla_call.func_op)
  return func.CallOp(inner_xla_call.func_op, in_vals).results
hlo_translations[xla_call_p] = xla_call_translation
@jit
def f(x):
  print('tracing!')
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
tracing!
2.7177599838802657
2.979984993200891
y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
[ 0.         -0.68294197  0.18140515]

One piece missing is device memory persistence for arrays. That is, we've defined handle_result to transfer results back to CPU memory as NumPy arrays, but it's often preferable to avoid transferring results just to transfer them back for the next operation. We can do that by introducing an Array class, which can wrap XLA buffers and otherwise duck-type numpy.ndarrays:

def handle_result(aval: ShapedArray, buf):  # noqa: F811
  return Array(aval, buf)

class Array:
  buf: Any
  aval: ShapedArray

  def __init__(self, aval, buf):
    self.aval = aval
    self.buf = buf

  dtype = property(lambda self: self.aval.dtype)
  shape = property(lambda self: self.aval.shape)
  ndim  = property(lambda self: self.aval.ndim)

  def __array__(self): return np.asarray(self.buf)
  def __repr__(self):  return repr(np.asarray(self.buf))
  def __str__(self):   return str(np.asarray(self.buf))

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(add)
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf

jax_types.add(Array)
@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891


def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
  rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call

Part 4: linearize and vjp (and grad!)#

The linearize and vjp autodiff functions are built on jvp, but they also involve jaxprs. That's because both involve staging out, or delaying, computation.

linearize#

In the case of linearize, we want to stage out the linear part of a jvp computation. That is, in terms of Haskell-like type signatures, if we have jvp : (a -> b) -> (a, T a) -> (b, T b), then we write linearize : (a -> b) -> a -> (b, T a -o T b), where we use T a to mean "the tangent type of a" and use the "lollipop" -o rather than the arrow -> to indicate a linear function. We define the semantics of linearize in terms of jvp too:

y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)

gives the same result for (y, y_dot) as

y, y_dot = jvp(f, (x,), (x_dot,))

where the application of f_lin does not redo any of the linearization work. We'll represent the delayed linear part f_lin : T a -o T b as a jaxpr.

Tangentially, now that we have linear arrows -o, we can provide a slightly more informative type for jvp:

jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)

Here we're writing UnrestrictedUse just to indicate that we have a special pair whose first element can be used in an unrestricted (nonlinear) way. In conjunction with the linear arrow, this notation is just meant to express that the function jvp f uses its first input in a nonlinear way but its second input in a linear way, producing a corresponding nonlinear output (which can be used in a nonlinear way) paired with a linear output. This more refined type signature encodes the data dependencies in jvp f, which are useful for partial evaluation.

To build the f_lin jaxpr from a JVP, we need to perform partial evaluation: we evaluate all the primal values as we trace, but stage the tangent computations into a jaxpr. This is our second way to build jaxprs. But where make_jaxpr and its underlying JaxprTrace/JaxprTracer interpreters aim to stage out every primitive bind, this second approach stages out only those primitive binds with a data dependence on tangent inputs.

First, some utilities:

def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out
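
A quick, purely illustrative check of these helpers (expected outputs in the comments):

print(split_half([1, 2, 3, 4]))                                  # ([1, 2], [3, 4])
print(merge_lists([False, True, True, False], [1, 2], [3, 4]))   # [1, 3, 4, 2]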

Next, we'll write linearize by combining jvp together with a general partial evaluation transformation, to be added next:

def linearize_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
  return primals_out, f_lin

def linearize(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_lin(*tangents_in):
    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
    if in_tree != in_tree2: raise TypeError
    tangents_out_flat = f_lin_flat(*tangents_in_flat)
    return tree_unflatten(out_tree(), tangents_out_flat)

  return primals_out, f_lin

def vspace(aval: ShapedArray) -> ShapedArray:
  return raise_to_shaped(aval)  # TODO handle integers?

Now we turn to the general partial evaluation transformation. The goal is to accept a Python callable and a list of inputs, some known and some unknown, and to produce (1) all the outputs which can be computed from the known inputs, together with (2) a jaxpr representing the part of the Python callable's computation which can only be performed once the remaining inputs are known.

This transformation is tricky to summarize in a type signature. If we assume the input function's type signature is (a1, a2) -> (b1, b2), where a1 and a2 represent the known and unknown inputs, respectively, and where b1 only has a data dependence on a1 while b2 has some data dependence on a2, then we might write:

partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)

In words, given values for the inputs of type a1, partial_eval produces the outputs of type b1, along with "residual" values of existentially-quantified type r representing the intermediates required to complete the computation in the second stage. It also produces a function of type (r, a2) -> b2 which accepts the residual values as well as the remaining inputs and produces the remaining outputs.

We like to think of partial evaluation as "unzipping" one computation into two. For example, consider this jaxpr:

{ lambda a:float64[] .
  let b:float64[] = sin a
      c:float64[] = neg b
  in ( c ) }

A jaxpr for the JVP would look like:

{ lambda a:float64[] b:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      e:float64[] = mul d b
      f:float64[] = neg c
      g:float64[] = neg e
  in ( f, g ) }

And if we imagine applying partial evaluation to this jaxpr, with the first input known and the second unknown, we end up "unzipping" the JVP jaxpr into a primal jaxpr and a tangent jaxpr:

{ lambda a:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      f:float64[] = neg c
  in ( f, d ) }
{ lambda d:float64[] b:float64[] .
  let e:float64[] = mul d b
      g:float64[] = neg e
  in ( g ) }

This second jaxpr represents the linear computation that we want from linearize.

Unlike in this jaxpr example, though, we want to perform the computation on known values while evaluating the input Python callable. That is, rather than forming a jaxpr for the entire function (a1, a2) -> (b1, b2), staging all operations out of Python first before sorting out what can be evaluated now and what must be delayed, we want only to form a jaxpr for those operations that must be delayed due to a dependence on unknown inputs. In the context of automatic differentiation, this is the feature that ultimately enables us to handle functions like grad(lambda x: x**2 if x > 0 else 0.). Python control flow works because partial evaluation keeps the primal computation in Python. As a consequence, our Trace and Tracer subclasses must on the fly sort out what can be evaluated and what must be staged out into a jaxpr.

First, we start with a PartialVal class, which represents a value that can be either known or unknown:

class PartialVal(NamedTuple):
  aval: ShapedArray
  const: Any | None

  @classmethod
  def known(cls, val: Any):
    return PartialVal(get_aval(val), val)

  @classmethod
  def unknown(cls, aval: ShapedArray):
    return PartialVal(aval, None)

  is_known   = property(lambda self: self.const is not None)
  is_unknown = property(lambda self: self.const is     None)

Partial evaluation will take a list of PartialVals representing inputs, and return a list of PartialVal outputs along with a jaxpr representing the delayed computation:

def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
  with new_main(PartialEvalTrace) as main:
    trace = PartialEvalTrace(main)
    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    pvals_out = [t.pval for t in tracers_out]
    unk_tracers_in  = [t for t in tracers_in  if t.pval.is_unknown]
    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
  return jaxpr, pvals_out, consts

Next we need to implement PartialEvalTrace and its PartialEvalTracer. This interpreter will build a jaxpr on the fly while tracking data dependencies. To do so, it builds a bipartite directed acyclic graph (DAG) between PartialEvalTracer nodes, representing staged-out values, and JaxprRecipe nodes, representing formulas for how to compute some values from others. One kind of recipe is a JaxprEqnRecipe, corresponding to a JaxprEqn's primitive application, but we also have recipe types for constants and lambda binders.

from weakref import ref, ReferenceType

class LambdaBindingRecipe(NamedTuple):
  pass

class ConstRecipe(NamedTuple):
  val: Any

class JaxprEqnRecipe(NamedTuple):
  prim: Primitive
  tracers_in: list['PartialEvalTracer']
  params: dict[str, Any]
  avals_out: list[ShapedArray]
  tracer_refs_out: list['ReferenceType[PartialEvalTracer]']

JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
  pval: PartialVal
  recipe: JaxprRecipe | None

  def __init__(self, trace, pval, recipe):
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  aval = property(lambda self: self.pval.aval)

  def full_lower(self):
    if self.pval.is_known:
      return full_lower(self.pval.const)
    return self

The PartialEvalTrace contains the logic for constructing the graph of JaxprRecipes and PartialEvalTracers. Each argument corresponds to a LambdaBindingRecipe leaf node, and each constant is a ConstRecipe leaf node holding a reference to the constant. All other tracers and recipes come from process_primitive, which forms tracers with JaxprEqnRecipes.

For most primitives, the process_primitive logic is straightforward: if all inputs are known, then we can bind the primitive on the known values (evaluating it in Python) and avoid forming tracers corresponding to the output. If instead any input is unknown, we stage out a JaxprEqnRecipe representing the primitive application. To build the tracers representing unknown outputs, we need avals, which we get from the abstract eval rules. (Notice that tracers reference JaxprEqnRecipes, and JaxprEqnRecipes reference tracers; we avoid circular garbage by using weakrefs.)

This process_primitive logic applies to most primitives, but xla_call_p requires recursive treatment. So we special-case its rule in a partial_eval_rules dict.

class PartialEvalTrace(Trace):
  def new_arg(self, pval: PartialVal) -> Any:
    return PartialEvalTracer(self, pval, LambdaBindingRecipe())

  def lift(self, val: Any) -> PartialEvalTracer:
    return PartialEvalTracer(self, PartialVal.known(val), None)
  pure = lift

  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
    if tracer.pval.is_unknown:
      return tracer
    else:
      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))

  def process_primitive(self, primitive, tracers, params):
    if all(t.pval.is_known for t in tracers):
      return bind(primitive, *map(full_lower, tracers), **params)
    rule = partial_eval_rules.get(primitive)
    if rule: return rule(self, tracers, **params)
    tracers_in = [self.instantiate_const(t) for t in tracers]
    avals_in = [t.aval for t in tracers_in]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                   for aval in avals_out]
    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
                         map(ref, tracers_out))
    for t in tracers_out: t.recipe = eqn
    return tracers_out

partial_eval_rules = {}

Now that we can build a graph representation of a jaxpr with PartialEvalTrace, we need a mechanism to convert the graph representation into a standard jaxpr. The jaxpr corresponds to a topological sort of the graph.

def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
  tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                   for t in tracers_in}
  constvar_to_val: dict[Var, Any] = {}
  constid_to_var: dict[int, Var] = {}
  processed_eqns: set[int] = set()
  eqns: list[JaxprEqn] = []
  for t in toposort(tracers_out, tracer_parents):
    if isinstance(t.recipe, LambdaBindingRecipe):
      assert id(t) in set(map(id, tracers_in))
    elif isinstance(t.recipe, ConstRecipe):
      val = t.recipe.val
      var = constid_to_var.get(id(val))
      if var is None:
        aval = raise_to_shaped(get_aval(val))
        var = constid_to_var[id(val)] = Var(aval)
        constvar_to_val[var] = val
      tracer_to_var[id(t)] = var
    elif isinstance(t.recipe, JaxprEqnRecipe):
      if id(t.recipe) not in processed_eqns:
        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
        processed_eqns.add(id(t.recipe))
    else:
      raise TypeError(t.recipe)

  constvars, constvals = unzip2(constvar_to_val.items())
  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
  jaxpr = Jaxpr(in_binders, eqns, out_vars)
  typecheck_jaxpr(jaxpr)
  return jaxpr, constvals

def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
  out_binders = [Var(aval) for aval in recipe.avals_out]
  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)

def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []


def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)

  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[id(node)] -= 1

  sorted_nodes = []
  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in parents(node):
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  sorted_nodes = sorted_nodes[::-1]
  check_toposort(sorted_nodes, parents)
  return sorted_nodes

def remove_duplicates(lst):
  seen = set()
  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]

def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
  seen = set()
  for node in nodes:
    assert all(id(parent) in seen for parent in parents(node))
    seen.add(id(node))

Now we can linearize!

y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454
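
We can also poke at partial_eval_flat directly. Here's a quick illustration (just a sketch; the names x_pval and y_pval are only for this example): with x known and y unknown, the multiplication has to be staged out, the known value becomes a constant, and the single output comes back unknown.

x_pval = PartialVal.known(np.float64(2.))
y_pval = PartialVal.unknown(ShapedArray((), np.dtype('float64')))
jaxpr, pvals_out, consts = partial_eval_flat(
    lambda x, y: [mul(x, y)], [x_pval, y_pval])
print(typecheck_jaxpr(jaxpr))  # type of the staged-out computation
print(pvals_out, consts)       # the output is unknown; consts holds the 2.0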

To handle linearize-of-jit, we still need to write a partial evaluation rule for xla_call_p. Aside from tracer bookkeeping, the main task is to perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.

There are in fact two rules to write: one for trace-time partial evaluation, which we'll call xla_call_partial_eval, and one for partial evaluation of jaxprs, which we'll call xla_call_peval_eqn.

def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # Unused
  in_unknowns = [not t.pval.is_known for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval

def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: list[bool] | None = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
  env: dict[Var, bool] = {}
  residuals: set[Var] = set()

  def read(x: Atom) -> bool:
    return type(x) is Var and env[x]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk

  def new_res(x: Atom) -> Atom:
    if type(x) is Var: residuals.add(x)
    return x

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks_in = map(read, eqn.inputs)
    rule = partial_eval_jaxpr_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
      map(write, unks_out, eqn.out_binders)
    elif any(unks_in):
      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
      map(partial(write, True), eqn.out_binders)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)
  assert all(type(v) is Var for v in residuals), residuals

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

  return jaxpr1, jaxpr2, out_unknowns, num_res

def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
  jaxprty = typecheck_jaxpr(jaxpr)    # (a1,  a2) -> (b1, b2 )
  jaxpr1ty = typecheck_jaxpr(jaxpr1)  #  a1       -> (b1, res)
  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

  a1, a2 = partition_list(unks_in, jaxprty.in_types)
  b1, b2 = partition_list(unks_out, jaxprty.out_types)
  b1_, res = split_list(jaxpr1ty.out_types, len(b1))
  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
  b2_ = jaxpr2ty.out_types

  if jaxpr1ty.in_types != a1: raise TypeError
  if jaxpr2ty.out_types != b2: raise TypeError
  if b1 != b1_: raise TypeError
  if res != res_: raise TypeError
  if a2 != a2_: raise TypeError
  if b2 != b2_: raise TypeError

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
  ins1, ins2 = partition_list(unks_in, eqn.inputs)
  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
                  out_binders1 + residuals)
  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
  return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn

With that, we can compose linearize and jit however we like:

@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
@jit
def f(x):
  y = sin(x) * 2.
  z = g(x, y)
  return z

@jit
def g(x, y):
  return cos(x) + y

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758

vjpgrad#

The vjp transformation works a lot like linearize. Its type signature is analogous:

linearize : (a -> b) -> a -> (b, T a -o T b)
vjp       : (a -> b) -> a -> (b, T b -o T a)

The only difference is that we transpose the linear part of the computation before returning it, so the type goes from T a -o T b to T b -o T a. That is, we implement vjp as, essentially:

def vjp(f, x):
  y, f_lin = linearize(f, x)
  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
  return y, f_vjp

Since we have the linear computation as a jaxpr, not just a Python callable, we can implement the transpose transformation as a jaxpr interpreter.

def vjp_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
  return primals_out, f_vjp

def vjp(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_vjp(*cotangents_out):
    cotangents_out_flat, _ = tree_flatten(cotangents_out)
    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
    return tree_unflatten(in_tree, cotangents_in_flat)

  return primals_out, f_vjp

class UndefPrimal(NamedTuple):
  aval: ShapedArray

register_pytree_node(UndefPrimal,
                     lambda u: (u.aval, ()),
                     lambda aval, _: UndefPrimal(aval))

We use UndefPrimal instances to indicate the arguments with respect to which we want to transpose. These arise because, to be explicit about closed-over values, we generally want to transpose functions of type a -> b -o c into functions of type a -> c -o b. Even more generally, the inputs with respect to which the function is linear could be scattered through the argument list, so we mark the linear positions with UndefPrimal. We register UndefPrimal as a pytree node because the pytree mechanism gives a handy way to prune these placeholders out of argument lists.
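
To see that pruning in action, here's a small sketch (the name example_aval is just for this illustration): an UndefPrimal node contributes no pytree leaves, so flattening an argument list keeps only the defined values.

example_aval = ShapedArray((), np.dtype('float64'))
leaves, treedef = tree_flatten([2.0, UndefPrimal(example_aval), 3.0])
print(leaves)  # placeholders are pruned; only the defined values [2.0, 3.0] remain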

Next we can write eval_jaxpr_transposed, along with transpose rules for all the primitives that can be linear in at least one argument:

# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}

  def read_primal(x: Atom) -> Any:
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val

  def write_primal(v: Var, val: Any) -> None:
    if type(val) is not UndefPrimal:
      primal_env[v] = val

  def read_cotangent(v: Var) -> Any:
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))

  def write_cotangent(x: Atom, val: Any):
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val

  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)

  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]

transpose_rules = {}
def mul_transpose_rule(cts, x, y):
  z_bar, = cts
  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule

def neg_transpose_rule(cts, x):
  ybar, = cts
  assert type(x) is UndefPrimal
  return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule

def add_transpose_rule(cts, x, y):
  z_bar, = cts
  return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule

def reduce_sum_transpose_rule(cts, x, *, axis):
  y_bar, = cts
  return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule

def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  del num_consts  # Unused
  undef_primals = [type(x) is UndefPrimal for x in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
  residuals, _ = partition_list(undef_primals, invals)
  outs = bind(xla_call_p, *new_consts, *residuals, *cts,
              jaxpr=transposed_jaxpr, num_consts=len(new_consts))
  outs = iter(outs)
  return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule

@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts

Now that we can linearize and transpose, we can finally write grad:

def grad(f):
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != (): raise TypeError
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun

y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(np.float64(-0.9899924966004454),) -0.9899924966004454
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(grad(f)(3.))
2.979984993200891
@jit
def f(x):
  y = x * 2.
  z = g(y)
  return z

@jit
def g(x):
  return cos(x) * 2.

print(grad(f)(3.))
1.1176619927957034
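
One detail worth making explicit: gradfun differentiates with respect to its first argument only, even though vjp produces cotangents for every input. A quick check of that behavior (a sketch we can run here):

print(grad(lambda x, y: x * y)(3., 2.))  # d/dx of x * y at (3., 2.) is y = 2.0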

Here's something of a compositionality stress test:

# from core_test.py fun_with_nested_calls_2
def foo(x):
  @jit
  def bar(y):
    def baz(w):
      q = jit(lambda x: y)(x)
      q = q + jit(lambda: y)()
      q = q + jit(lambda y: w + y)(y)
      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
      return q
    p, t = jvp(baz, (x + 1.0,), (y,))
    return t + (x * p)
  return bar(x)

def assert_allclose(*vals):
  for v1, v2 in zip(vals[:-1], vals[1:]):
    np.testing.assert_allclose(v1, v2)

ans1 = foo(3.)
ans2 = jit(foo)(3.)
ans3, _ = jvp(foo, (3.,), (5.,))
ans4, _ = jvp(jit(foo), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)

deriv1 = grad(foo)(3.)
deriv2 = grad(jit(foo))(3.)
deriv3 = jit(grad(jit(foo)))(3.)
_, deriv4 = jvp(foo, (3.,), (1.,))
_, deriv5 = jvp(jit(foo), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)

hess1 = grad(grad(foo))(3.)
hess2 = grad(grad(jit(foo)))(3.)
hess3 = grad(jit(grad(foo)))(3.)
hess4 = jit(grad(grad(foo)))(3.)
_, hess5 = jvp(grad(foo), (3.,), (1.,))
_, hess6 = jvp(jit(grad(foo)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(foo)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)

Part 5: the control flow primitive cond#

Next we'll add higher-order primitives for staged-out control flow. These resemble jit from Part 3, another higher-order primitive, but differ in that they are parameterized by multiple callables rather than just one.

Adding cond#

We introduce a cond primitive to represent the conditional application of one function or another inside a jaxpr. We write the type of cond as Bool -> (a -> b) -> (a -> b) -> a -> b. In other words, cond takes a boolean representing the predicate and two functions of equal types. Depending on the value of the predicate, it applies one function or the other to its final argument.

In Python we represent it as a function which itself takes two functions as arguments. As with jit, the first step is to call make_jaxpr on the callable arguments to turn them into jaxprs:

def cond(pred, true_fn, false_fn, *operands):
  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
  if out_tree != out_tree_: raise TypeError
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  outs = bind_cond(pred, *true_consts, *false_consts, *operands,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')

def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
  new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
  return new_jaxpr1, new_jaxpr2

def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
  return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)

We require true_jaxpr and false_jaxpr to have the same type, but because they might close over different constants (and because jaxprs can only represent closed terms, i.e. they can't have free variables and are instead closure-converted), we need the helper _join_jaxpr_consts to make the input binder lists of the two jaxprs consistent. (To be more economical we could try to identify pairs of constants with the same shapes, but instead we simply concatenate the lists of constants.)
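
Here's a tiny hand-built illustration of what _join_jaxpr_consts does (just a sketch; the jaxprs and names are made up for the example): joining a jaxpr with one constant binder against one with none yields two jaxprs of identical type.

scalar = ShapedArray((), np.dtype('float64'))
c, a1, a2 = Var(scalar), Var(scalar), Var(scalar)
j1 = Jaxpr([c, a1], [], [a1])  # one const binder, then one argument
j2 = Jaxpr([a2], [], [a2])     # no const binders, one argument
j1_joined, j2_joined = _join_jaxpr_consts(j1, j2, 1, 0)
print(typecheck_jaxpr(j1_joined) == typecheck_jaxpr(j2_joined))  # True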

Next we can turn to adding interpreter rules for cond. Its evaluation rule is simple:

def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  if pred:
    return eval_jaxpr(true_jaxpr, operands)
  else:
    return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl

out = cond(True, lambda: 3, lambda: 4)
print(out)
3

For its JVP and vmap rules, we only need to call the same jvp_jaxpr and vmap_jaxpr utilities we created for jit, followed by another pass of _join_jaxpr_consts:

def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  pred, *primals = primals
  _   , *tangents = tangents
  true_jaxpr , true_consts  = jvp_jaxpr(true_jaxpr)
  false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule

out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  pred    , *vals_in = vals_in
  pred_dim, *dims_in = dims_in
  if pred_dim is not not_mapped: raise NotImplementedError  # TODO
  true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
  false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule

xs = np.array([1., 2., 3.])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
[2. 3. 4.]

Notice that we don't currently support the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the conditional to a select primitive. That transformation is semantically correct so long as true_fun and false_fun do not involve any side-effecting primitives.
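
As a rough sketch of that select idea in plain numpy (not how the cond implemented here works): run both side-effect-free branches in batched form and pick elementwise.

pred = np.array([True, False, True])
xs = np.array([1., 2., 3.])
true_out = xs + 1.             # batched result of the true branch
false_out = np.zeros_like(xs)  # batched result of the false branch
print(np.where(pred, true_out, false_out))  # [2. 0. 4.]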

Another thing not represented here, but present in mainline JAX, is that applying a transformation to two jaxprs of equal type might result in jaxprs of different types. For example, applying the mainline-JAX version of vmap_jaxpr to the identity-function jaxpr

{ lambda a:float32[] .
  let
  in ( a ) }

would, for a batch size of 10, result in a jaxpr with a batched output of type [float32[10]] -> [float32[10]], whereas applying it to the zero-function jaxpr

{ lambda a:float32[] .
  let
  in ( 0. ) }

would result in a jaxpr with an unbatched output of type [float32[10]] -> [float32[]]. This is an optimization, aimed at not batching values unnecessarily, but it means that in cond we would need an extra step of joining the two transformed jaxprs so that they have consistent output types. We don't need that step here because we chose vmap_jaxpr to always batch all outputs over the leading axis.

Next we can turn to the abstract evaluation and XLA lowering rules:

def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval

def cond_translation(c, in_avals, out_avals, in_vals, *, true_jaxpr, false_jaxpr):
  del in_avals  # Unused
  pred, *in_vals = in_vals

  op = hlo.IfOp([aval_to_ir_type(aval) for aval in out_avals], pred)
  with ir.InsertionPoint(op.true_branch.blocks.append()):
    hlo.return_(jaxpr_subcomp(c, true_jaxpr, in_vals))
  with ir.InsertionPoint(op.false_branch.blocks.append()):
    hlo.return_(jaxpr_subcomp(c, false_jaxpr, in_vals))
  return op.results

hlo_translations[cond_p] = cond_translation

out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
2

Finally, to support reverse-mode automatic differentiation, we need partial evaluation and transposition rules. For partial evaluation, we need to introduce another jaxpr-munging utility, _join_jaxpr_res, to handle the fact that applying partial evaluation to true_fun and false_fun will in general produce distinct residuals. We use _join_jaxpr_res to make the output types of the transformed jaxprs consistent (while _join_jaxpr_consts dealt with the input types).

def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  pred_tracer, *tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in tracers]

  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs

  known_tracers, unknown_tracers = partition_list(in_uks, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind_cond(pred, *known_vals,
                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval

def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
                       ) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  out_uks = map(op.or_, t_out_uks, f_out_uks)

  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)

  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
  num_res = t_nres + f_nres

  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res

def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
  out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
  assert out_types1 == out_types2
  outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
  outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
  zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
  zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
  new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
  new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
  return new_jaxpr1, new_jaxpr2

_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                   ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
  pred_unk, *unks_in = unks_in
  assert not pred_unk
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
  outs1, outs2 = partition_list(unks_out, eqn.out_binders)
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                  outs1 + residuals)
  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                  outs2)
  res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
  return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn

_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
3.14

Transposition is a fairly straightforward application of transpose_jaxpr:

def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  undef_primals = tuple(type(x) is UndefPrimal for x in invals)
  true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
  false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  res = [x for x in invals if type(x) is not UndefPrimal]
  outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  outs = iter(outs)
  return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule

out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
2.0


def pprint_cond(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(true_jaxpr).indent(2),
               pp_jaxpr(false_jaxpr).indent(2)])
pp_rules[cond_p] = pprint_cond