Autodidax: JAX core from scratch#
Ever want to learn how JAX works, but the implementation seemed impenetrable? Well, you're in luck! By reading this tutorial, you'll learn every big idea in JAX's core system. You'll even get clued into our weird jargon!
This is a work-in-progress draft. There are some important ingredients missing, still to come in parts 5 and 6 (and more?). There are also some simplifications here that we haven't yet applied to the main system, but we will.
Part 1: Transformations as interpreters: standard evaluation, jvp, and vmap#
We want to transform functions that look like this:
def f(x):
y = sin(x) * 2.
z = - y + x
return z
We want to think of functions like sin, along with the arithmetic operations underlying the infix operators (mul, add, and neg), as primitives, meaning atomic units of processing rather than compositions.
"Transform" means "interpret differently." Instead of standard interpretation, where we apply primitives to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of its JVP rule, and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters.
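As a standalone preview of that idea (a toy dual-number interpreter, not the machinery this tutorial builds), here is what "letting primal-tangent pairs flow through the program" can mean:

```python
# A standalone toy: reinterpret a program by letting primal-tangent pairs
# flow through it instead of plain numbers. All names here are illustrative.
import math

class Dual:
    """A primal-tangent pair."""
    def __init__(self, primal, tangent):
        self.primal, self.tangent = primal, tangent

def sin(x):
    if isinstance(x, Dual):  # the "JVP rule" for sin
        return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)
    return math.sin(x)

def mul(x, y):  # multiply by a constant, enough for this example
    if isinstance(x, Dual):
        return Dual(x.primal * y, x.tangent * y)
    return x * y

def add(x, y):
    if isinstance(x, Dual) and isinstance(y, Dual):
        return Dual(x.primal + y.primal, x.tangent + y.tangent)
    return x + y

def neg(x):
    if isinstance(x, Dual):
        return Dual(-x.primal, -x.tangent)
    return -x

def f(x):
    y = mul(sin(x), 2.)
    z = add(neg(y), x)
    return z

out = f(Dual(3.0, 1.0))
print(out.primal, out.tangent)  # value and derivative of f at 3.0
```

The real system does this interception generically, for every primitive and every transformation, without editing the user's function.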
JAX core machinery#
We can implement stacks of interpreters and even have them all discharge on the fly as we execute the Python function to be transformed. To start, let's define these primitives so that we can intercept their application:
from typing import NamedTuple
class Primitive(NamedTuple):
name: str
add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")
def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
if axis is None:
axis = tuple(range(np.ndim(x)))
if type(axis) is int:
axis = (axis,)
return bind1(reduce_sum_p, x, axis=axis)
def bind1(prim, *args, **params):
out, = bind(prim, *args, **params)
return out
We'll set up array data types and infix operator methods in a moment.
A Primitive is just an object with a name, to which we attach our interpretation rules (one for each transformation). The bind function is our interception point: it'll figure out which transformation rule to apply, based on how the arguments are boxed in tracers and what interpreters are active.
The functions that user code calls, like add and sin, are just wrappers around calls to bind. These wrappers let us control how arguments are passed to bind, and in particular we follow a handy internal convention: when we call bind, we pass values representing array data as positional arguments, and we pass metadata like the axis argument to reduce_sum_p via keyword. This calling convention simplifies some core logic (since e.g. instances of the Tracer class to be defined below can only occur in positional arguments to bind). The wrappers can also provide docstrings!
We represent active interpreters as a stack. The stack is just a simple list, and each element is a container with an integer level (corresponding to the element's height in the stack), an interpreter type (which we'll call a trace_type), and an optional field for any global data the interpreter needs. We call each element a MainTrace, though maybe "Interpreter" would be more descriptive.
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any
class MainTrace(NamedTuple):
level: int
trace_type: type['Trace']
global_data: Any | None
trace_stack: list[MainTrace] = []
dynamic_trace: MainTrace | None = None # to be employed in Part 3
@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
level = len(trace_stack)
main = MainTrace(level, trace_type, global_data)
trace_stack.append(main)
try:
yield main
finally:
trace_stack.pop()
We'll use new_main to push another interpreter onto the stack when we're about to apply a transformation. Then, as we apply primitives in the function, we can think of the bind first being interpreted by the trace at the top of the stack (i.e. with the highest level). If that first interpreter itself binds other primitives in its interpretation rule for the primitive, like how the JVP rule for sin_p might bind cos_p and mul_p, then those bind calls will be handled by the interpreter at the next level down.
What goes at the bottom of the interpreter stack? At the bottom, we know all the transformation interpreters are finished, and we just want to do standard evaluation. So at the bottom we'll put an evaluation interpreter.
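As a standalone toy (the names here are hypothetical, not the tutorial's code), the stack discipline works like this: the most recently pushed interpreter has the highest level and handles an operation first, and the bottom entry stands for plain evaluation:

```python
# Toy sketch of the interpreter-stack discipline. Each entry is (name, level);
# the entry with the highest level handles an operation first.
from contextlib import contextmanager

stack = [("eval", 0)]  # special bottom of the stack

@contextmanager
def push(name):
    stack.append((name, len(stack)))
    try:
        yield
    finally:
        stack.pop()

def top():
    return max(stack, key=lambda entry: entry[1])

order = []
with push("jvp"):
    with push("vmap"):
        order.append(top()[0])   # vmap is highest, so it goes first
    order.append(top()[0])       # then jvp
order.append(top()[0])           # finally plain evaluation

print(order)
```

In the real system, each interpreter's rule re-binds primitives, which is what hands the work down to the next level.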
Let's sketch out the interface for interpreters, which is based on the Trace and Tracer base classes. A Tracer represents a boxed-up value, perhaps carrying some extra context data used by the interpreter. A Trace handles boxing up values into Tracers and also handles primitive application.
class Trace:
main: MainTrace
def __init__(self, main: MainTrace) -> None:
self.main = main
def pure(self, val): assert False # must override
def lift(self, val): assert False # must override
def process_primitive(self, primitive, tracers, params):
assert False # must override
The first two methods are about boxing up values in Tracers, which are the objects that flow through the Python programs we transform. The last method is the callback we'll use to interpret primitive application.
The Trace itself doesn't contain any data, other than a reference to its corresponding MainTrace instance. In fact, multiple instances of a Trace might be created and discarded during an application of a transformation, whereas only a single MainTrace instance is created per application of a transformation.
As for Tracers themselves, each one carries an abstract value (and forwards infix operators to it), and the rest is up to the transformation. (The relationship between Tracers and AbstractValues is that there's one Tracer per transformation, and at least one AbstractValue per base type, like arrays.)
import numpy as np
class Tracer:
_trace: Trace
__array_priority__ = 1000
@property
def aval(self):
assert False # must override
def full_lower(self):
return self # default implementation
def __neg__(self): return self.aval._neg(self)
def __add__(self, other): return self.aval._add(self, other)
def __radd__(self, other): return self.aval._radd(self, other)
def __mul__(self, other): return self.aval._mul(self, other)
def __rmul__(self, other): return self.aval._rmul(self, other)
def __gt__(self, other): return self.aval._gt(self, other)
def __lt__(self, other): return self.aval._lt(self, other)
def __bool__(self): return self.aval._bool(self)
def __nonzero__(self): return self.aval._nonzero(self)
def __getattr__(self, name):
try:
return getattr(self.aval, name)
except AttributeError:
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
def swap(f): return lambda x, y: f(y, x)
class ShapedArray:
array_abstraction_level = 1
shape: tuple[int, ...]
dtype: np.dtype
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
@property
def ndim(self):
return len(self.shape)
_neg = staticmethod(neg)
_add = staticmethod(add)
_radd = staticmethod(swap(add))
_mul = staticmethod(mul)
_rmul = staticmethod(swap(mul))
_gt = staticmethod(greater)
_lt = staticmethod(less)
@staticmethod
def _bool(tracer):
raise Exception("ShapedArray can't be unambiguously converted to bool")
@staticmethod
def _nonzero(tracer):
raise Exception("ShapedArray can't be unambiguously converted to bool")
def str_short(self):
return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'
def __hash__(self):
return hash((self.shape, self.dtype))
def __eq__(self, other):
return (type(self) is type(other) and
self.shape == other.shape and self.dtype == other.dtype)
def __repr__(self):
return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"
class ConcreteArray(ShapedArray):
array_abstraction_level = 2
val: np.ndarray
def __init__(self, val):
self.val = val
self.shape = val.shape
self.dtype = val.dtype
@staticmethod
def _bool(tracer):
return bool(tracer.aval.val)
@staticmethod
def _nonzero(tracer):
return bool(tracer.aval.val)
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
elif type(x) in jax_types:
return ConcreteArray(np.asarray(x))
else:
raise TypeError(x)
jax_types = {bool, int, float,
np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
Notice that we actually have two AbstractValues for arrays, representing different levels of abstraction. A ShapedArray represents the set of all possible arrays with a given shape and dtype. A ConcreteArray represents a singleton set consisting of a single array value.
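To make the two levels concrete, here is a self-contained mirror of these classes (a sketch, reproducing just the relevant pieces) showing that distinct concrete values share a single shaped type once the value is forgotten:

```python
# Self-contained mirror of the two abstraction levels: a ConcreteArray
# remembers one particular value, while raising it to a ShapedArray keeps
# only the shape/dtype "type" shared by every array of that shape.
import numpy as np

class ShapedArray:
    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, np.dtype(dtype)
    def __eq__(self, other):
        return (type(self) is type(other) and self.shape == other.shape
                and self.dtype == other.dtype)
    def __hash__(self):
        return hash((self.shape, self.dtype))

class ConcreteArray(ShapedArray):
    def __init__(self, val):
        super().__init__(val.shape, val.dtype)
        self.val = val

def raise_to_shaped(aval):
    return ShapedArray(aval.shape, aval.dtype)

a = ConcreteArray(np.zeros((2, 3), np.float64))
b = ConcreteArray(np.ones((2, 3), np.float64))
# Distinct concrete values, but the same shaped type once we forget the value:
print(raise_to_shaped(a) == raise_to_shaped(b))
```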
Now that we've set up the interpreter stack, the Trace/Tracer API for interpreters, and abstract values, we can come back to implement bind:
def bind(prim, *args, **params):
top_trace = find_top_trace(args)
tracers = [full_raise(top_trace, arg) for arg in args]
outs = top_trace.process_primitive(prim, tracers, params)
return [full_lower(out) for out in outs]
The main action is that we call find_top_trace to figure out which interpreter should handle this primitive application. We then call that top trace's process_primitive so the trace can apply its interpretation rule. The calls to full_raise just ensure that the inputs are boxed in the top trace's Tracer instances, and the call to full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.
import operator as op
def find_top_trace(xs) -> Trace:
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=trace_stack[0], key=op.attrgetter('level'))
if dynamic_trace and dynamic_trace.level > top_main.level:
top_main = dynamic_trace
return top_main.trace_type(top_main)
In words, ignoring the dynamic_trace step until Part 3, find_top_trace returns the highest-level interpreter associated with the Tracers on its inputs, and otherwise returns the interpreter at the bottom of the stack (which, at least for now, is always an evaluation trace). This is a deviation from the description above, where we always start by running the interpreter at the top of the stack and then work our way down, applying every interpreter in the stack. Instead, we only apply an interpreter when the input arguments to a primitive bind are boxed in a Tracer corresponding to that interpreter. This optimization lets us skip irrelevant transformations, but bakes in an assumption that transformations mostly follow data dependence (except for the special bottom-of-the-stack interpreter, which interprets everything).
An alternative would be to have every interpreter in the stack interpret every operation. That's worth exploring! JAX is designed around data dependence in large part because that's so natural for automatic differentiation, and JAX's roots are in autodiff. But it may be over-fit.
def full_lower(val: Any):
if isinstance(val, Tracer):
return val.full_lower()
else:
return val
def full_raise(trace: Trace, val: Any) -> Tracer:
if not isinstance(val, Tracer):
assert type(val) in jax_types
return trace.pure(val)
level = trace.main.level
if val._trace.main is trace.main:
return val
elif val._trace.main.level < level:
return trace.lift(val)
elif val._trace.main.level > level:
raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
else: # val._trace.level == level
raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
The logic in full_raise serves to box values into Tracers for a particular Trace, calling different methods on the Trace based on context: Trace.pure is called on non-Tracer constants, and Trace.lift is called for values that are already Tracers from a lower-level interpreter. These two methods could share the same implementation, but by distinguishing them in the core logic we can provide more information to the Trace subclass.
That's it for the JAX core! Now we can start adding interpreters.
Evaluation interpreter#
We'll start with the simplest interpreter: the evaluation interpreter that will sit at the bottom of the interpreter stack.
class EvalTrace(Trace):
pure = lift = lambda self, x: x # no boxing in Tracers needed
def process_primitive(self, primitive, tracers, params):
return impl_rules[primitive](*tracers, **params)
trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack
# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}
impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
for axis in sorted(axes):
x = np.expand_dims(x, axis)
return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl
With this interpreter, we can evaluate user functions:
def f(x):
y = sin(x) * 2.
z = - y + x
return z
print(f(3.0))
2.7177599838802657
Woo! Like going around in a big circle. But the point of this indirection is that now we can add some real transformations.
Forward-mode autodiff with jvp#
First, a few helper functions:
import builtins
def zeros_like(val):
aval = get_aval(val)
return np.zeros(aval.shape, aval.dtype)
def unzip2(pairs):
lst1, lst2 = [], []
for x1, x2 in pairs:
lst1.append(x1)
lst2.append(x2)
return lst1, lst2
def map(f, *xs):
return list(builtins.map(f, *xs))
def zip(*args):
fst, *rest = args = map(list, args)
n = len(fst)
for arg in rest:
assert len(arg) == n
return list(builtins.zip(*args))
The Tracer for forward-mode autodiff carries a primal-tangent pair. The Trace applies JVP rules.
class JVPTracer(Tracer):
def __init__(self, trace, primal, tangent):
self._trace = trace
self.primal = primal
self.tangent = tangent
@property
def aval(self):
return get_aval(self.primal)
class JVPTrace(Trace):
pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))
def process_primitive(self, primitive, tracers, params):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
jvp_rule = jvp_rules[primitive]
primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]
jvp_rules = {}
Notice both pure and lift package a value into a JVPTracer with the minimal amount of context, which is a zero tangent value.
Let's add some JVP rules for primitives:
def add_jvp(primals, tangents):
(x, y), (x_dot, y_dot) = primals, tangents
return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp
def mul_jvp(primals, tangents):
(x, y), (x_dot, y_dot) = primals, tangents
return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp
def sin_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp
def cos_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp
def neg_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp
def reduce_sum_jvp(primals, tangents, *, axis):
(x,), (x_dot,) = primals, tangents
return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp
def greater_jvp(primals, tangents):
(x, y), _ = primals, tangents
out_primal = greater(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp
def less_jvp(primals, tangents):
(x, y), _ = primals, tangents
out_primal = less(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp
Finally, we add a transformation API to kick off the trace:
def jvp_v1(f, primals, tangents):
with new_main(JVPTrace) as main:
trace = JVPTrace(main)
tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
out = f(*tracers_in)
tracer_out = full_raise(trace, out)
primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
return primal_out, tangent_out
And with that, we can differentiate!
x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
-0.9899924966004454
-0.9899924966004454
def f(x):
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def deriv(f):
return lambda x: jvp_v1(f, (x,), (1.,))[1]
print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672
def f(x):
if x > 0.: # Python control flow
return 2. * x
else:
return x
print(deriv(f)(3.))
print(deriv(f)(-3.))
2.0
1.0
Pytrees and flattening user functions' inputs and outputs#
A limitation with jvp_v1 is that it assumes the user function accepts arrays as positional arguments and produces a single array as output. What if it produced a list as output? Or accepted nested containers as inputs? It would be a pain to deal with all the possible containers in inputs and outputs at every layer of the stack. Instead, we can wrap the user function so that the wrapped version accepts arrays as inputs and returns a flat list of arrays as output. The wrapper just needs to unflatten its input, call the user function, and flatten the output.
Here's how we'd like to write jvp, assuming the user always gives us functions that take arrays as inputs and produce a flat list of arrays as outputs:
def jvp_flat(f, primals, tangents):
with new_main(JVPTrace) as main:
trace = JVPTrace(main)
tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
return primals_out, tangents_out
To support user functions that have arbitrary containers in the inputs and outputs, here's how we'd write the user-facing jvp wrapper:
def jvp(f, primals, tangents):
primals_flat, in_tree = tree_flatten(primals)
tangents_flat, in_tree2 = tree_flatten(tangents)
if in_tree != in_tree2: raise TypeError
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
return primals_out, tangents_out
Notice that we had to plumb the tree structure of the user function output back to the caller of flatten_fun. That value isn't available until we actually run the user function, so flatten_fun just returns a reference to a mutable cell, represented as a thunk. These side effects are safe because we always run the user function exactly once. (This safe regime is the reason for the "linear" name in linear_util.py, in the sense of linear types.)
All that remains is to write tree_flatten, tree_unflatten, and flatten_fun.
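Those implementations are omitted here; as a rough standalone sketch (much simplified relative to JAX's actual pytree and linear_util machinery, and handling only tuples, lists, and dicts), they might look like:

```python
# Simplified sketch of the elided helpers: flatten nested containers to a
# list of leaves plus a treedef, and wrap a user function to consume and
# produce flat lists. Not JAX's real implementation.
def tree_flatten(x):
    if isinstance(x, (tuple, list)):
        flats, defs = [], []
        for child in x:
            leaves, d = tree_flatten(child)
            flats += leaves; defs.append(d)
        return flats, (type(x), None, tuple(defs))
    elif isinstance(x, dict):
        keys = sorted(x)
        flats, defs = [], []
        for k in keys:
            leaves, d = tree_flatten(x[k])
            flats += leaves; defs.append(d)
        return flats, (dict, tuple(keys), tuple(defs))
    else:
        return [x], None  # a leaf

def tree_unflatten(treedef, leaves):
    leaves = iter(leaves)
    def build(d):
        if d is None:
            return next(leaves)
        ty, keys, child_defs = d
        children = [build(cd) for cd in child_defs]
        return dict(zip(keys, children)) if ty is dict else ty(children)
    return build(treedef)

def flatten_fun(f, in_tree):
    store = []  # mutable cell, written exactly once when f runs
    def flat_f(*flat_args):
        args = tree_unflatten(in_tree, flat_args)
        out = f(*args)
        out_flat, out_tree = tree_flatten(out)
        store.append(out_tree)
        return out_flat
    return flat_f, lambda: store[0]  # thunk for the output tree

leaves, treedef = tree_flatten({'hi': 1., 'there': [2., 3.]})
print(leaves)
print(tree_unflatten(treedef, leaves))
```

The thunk returned by flatten_fun is only safe to call after the wrapped function has run, which is exactly the "run exactly once" regime described above.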
With this pytree-handling jvp implementation, we can now handle arbitrary input and output containers. That'll come in handy with future transformations too!
def f(x):
y = sin(x) * 2.
z = - y + x
return {'hi': z, 'there': [x, y]}
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
{'hi': np.float64(2.7177599838802657), 'there': [3.0, np.float64(0.2822400161197344)]}
{'hi': np.float64(2.979984993200891), 'there': [1.0, np.float64(-1.9799849932008908)]}
Vectorized batching with vmap#
First, a couple of helper functions: one for producing mapped abstract values from unmapped ones (by removing an axis), and one for moving batch dimensions around:
def mapped_aval(batch_dim, aval):
shape = list(aval.shape)
del shape[batch_dim]
return ShapedArray(tuple(shape), aval.dtype)
def move_batch_axis(axis_size, src, dst, x):
if src is not_mapped:
target_shape = list(np.shape(x))
target_shape.insert(dst, axis_size)
return broadcast(x, target_shape, [dst])
elif src == dst:
return x
else:
return moveaxis(x, src, dst)
def moveaxis(x, src: int, dst: int):
perm = [i for i in range(np.ndim(x)) if i != src]
perm.insert(dst, src)
return transpose(x, perm)
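As a standalone sanity check (using np.transpose directly rather than the transpose primitive), the permutation built by moveaxis agrees with NumPy's own np.moveaxis:

```python
# Standalone check of the permutation logic in moveaxis above, with
# np.transpose standing in for the transpose primitive.
import numpy as np

def moveaxis(x, src, dst):
    perm = [i for i in range(np.ndim(x)) if i != src]
    perm.insert(dst, src)
    return np.transpose(x, perm)

x = np.arange(24).reshape(2, 3, 4)
out = moveaxis(x, 0, 2)
print(out.shape)  # the size-2 axis moves to the back
assert (out == np.moveaxis(x, 0, 2)).all()
```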
The Tracer for vectorized batching carries a batched value and an optional integer indicating which axis (if any) is the batch axis.
from typing import Union
class NotMapped: pass
not_mapped = NotMapped()
BatchAxis = Union[NotMapped, int]
class BatchTracer(Tracer):
def __init__(self, trace, val, batch_dim: BatchAxis):
self._trace = trace
self.val = val
self.batch_dim = batch_dim
@property
def aval(self):
if self.batch_dim is not_mapped:
return get_aval(self.val)
else:
return mapped_aval(self.batch_dim, get_aval(self.val))
def full_lower(self):
if self.batch_dim is not_mapped:
return full_lower(self.val)
else:
return self
class BatchTrace(Trace):
pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)
def process_primitive(self, primitive, tracers, params):
vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
vmap_rule = vmap_rules[primitive]
val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]
@property
def axis_size(self):
return self.main.global_data
vmap_rules = {}
Here we've implemented the optional Tracer.full_lower method, which lets us peel off a batching tracer if it's not needed because it doesn't represent a batched value.
For BatchTrace, analogous to JVPTrace, the methods pure and lift just box a value in a BatchTracer with the minimal amount of context, which in this case is a batch_dim taking the sentinel value not_mapped. Notice we use the MainTrace's interpreter-global data field to store the batch axis size.
Next we can define batching interpreter rules for each primitive:
from functools import partial
def binop_batching_rule(op, axis_size, vals_in, dims_in):
(x, y), (x_bdim, y_bdim) = vals_in, dims_in
if x_bdim != y_bdim:
if x_bdim is not_mapped:
x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
x_bdim = y_bdim
else:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
(x,), (x_bdim,) = vals_in, dims_in
return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)
def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
(x,), (x_bdim,) = vals_in, dims_in
new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
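The index arithmetic in reduce_sum_batching_rule is easy to get wrong; here is a standalone NumPy check (using np.sum and np.moveaxis in place of the primitives) of the shifted summed axes and the resulting batch dimension:

```python
# Standalone sanity check of the axis arithmetic in reduce_sum_batching_rule:
# sum over logical axis 1 of each batch element, batch stored along axis 2.
import numpy as np

x = np.arange(2 * 3 * 5).reshape(2, 3, 5)   # a batch of 5 arrays of shape (2, 3)
x_bdim, axis = 2, (1,)                      # batch dim 2, sum over logical axis 1

new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)   # still (1,): axis 1 < bdim
out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)    # 2 - 1 == 1
out = np.sum(x, new_axis)

# Reference: move batch to the front, sum each element, move the batch back.
ref = np.moveaxis(np.sum(np.moveaxis(x, 2, 0), axis=2), 0, out_bdim)
print(out.shape, (out == ref).all())
```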
Finally, we add a transformation API to kick off the trace:
def vmap_flat(f, in_axes, *args):
axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
if ax is not not_mapped}
with new_main(BatchTrace, axis_size) as main:
trace = BatchTrace(main)
tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x
for x, ax in zip(args, in_axes)]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
for val_out, bdim in zip(vals_out, bdims_out)]
return outs_transposed
def vmap(f, in_axes):
def batched_f(*args):
args_flat, in_tree = tree_flatten(args)
in_axes_flat, in_tree2 = tree_flatten(in_axes)
if in_tree != in_tree2: raise TypeError
f_flat, out_tree = flatten_fun(f, in_tree)
outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
return tree_unflatten(out_tree(), outs_flat)
return batched_f
def add_one_to_a_scalar(scalar):
assert np.ndim(scalar) == 0
return 1 + scalar
vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)
print(vector_in)
print(vector_out)
[0. 1. 2.]
[1. 2. 3.]
def jacfwd(f, x):
pushfwd = lambda v: jvp(f, (x,), (v,))[1]
vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
return vmap(pushfwd, (0,))(vecs_in)
def f(x):
return sin(x)
jacfwd(f, np.arange(3.))
array([[ 1. , 0. , -0. ],
[ 0. , 0.54030231, -0. ],
[ 0. , 0. , -0.41614684]])
That's it for jvp and vmap!
Part 2: Jaxprs#
The next transformations on the horizon are jit for just-in-time compilation and vjp for reverse-mode autodiff. (grad is just a small wrapper around vjp.) Whereas jvp and vmap only needed each Tracer to carry a little bit of extra context, for jit and vjp we need much richer context: we need to represent *programs*. That is, we need jaxprs!
Jaxprs are JAX's internal intermediate representation of programs. They are explicitly typed, functional, first-order, and in ANF form. We need a program representation for jit because the purpose of jit is to stage computation out of Python. For any computation we want to stage out, we need to be able to represent it as data, and build it up as we trace a Python function. Similarly, vjp needs a way to represent the computation for the backward pass of reverse-mode autodiff. We use the same jaxpr program representation for both needs.
(Building a program representation is the most free-wheeling kind of trace-transformation, and so except for issues around handling native Python control flow, any transformation could be implemented by first tracing to a jaxpr and then interpreting the jaxpr.)
Jaxpr data structures#
The jaxpr term syntax is roughly:
jaxpr ::=
{ lambda <binder> , ... .
let <eqn>
...
in ( <atom> , ... ) }
binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>
eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...
The syntax of types is:
jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ...
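As a concrete example of a term in this grammar, the function f from Part 1 could be written roughly as follows (the variable names and layout here are illustrative, not exact pretty-printer output):

```
{ lambda a:f64[] .
  let b:f64[] = sin a
      c:f64[] = mul b 2.0
      d:f64[] = neg c
      e:f64[] = add d a
  in ( e ) }
```

Its type, in the type syntax above, is (f64[]) -> (f64[]).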
How do we represent these as Python data structures? We reuse ShapedArrays to represent types, and we can represent the term syntax with a few Python structs:
class Var:
aval: ShapedArray
def __init__(self, aval): self.aval = aval
class Lit:
val: Any
aval: ShapedArray
def __init__(self, val):
self.aval = aval = raise_to_shaped(get_aval(val))
self.val = np.array(val, aval.dtype)
Atom = Union[Var, Lit]
class JaxprEqn(NamedTuple):
primitive: Primitive
inputs: list[Atom]
params: dict[str, Any]
out_binders: list[Var]
class Jaxpr(NamedTuple):
in_binders: list[Var]
eqns: list[JaxprEqn]
outs: list[Atom]
def __hash__(self): return id(self)
__eq__ = op.is_
def raise_to_shaped(aval):
return ShapedArray(aval.shape, aval.dtype)
Type-checking a jaxpr involves checking that there are no unbound variables, that variables are only bound once, and that for each equation the type of the primitive application matches the type of the output binders.
class JaxprType(NamedTuple):
in_types: list[ShapedArray]
out_types: list[ShapedArray]
def __repr__(self):
in_types = ', '.join(aval.str_short() for aval in self.in_types)
out_types = ', '.join(aval.str_short() for aval in self.out_types)
return f'({in_types}) -> ({out_types})'
def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
env: set[Var] = set()
for v in jaxpr.in_binders:
if v in env: raise TypeError
env.add(v)
for eqn in jaxpr.eqns:
in_types = [typecheck_atom(env, x) for x in eqn.inputs]
out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
for out_binder, out_type in zip(eqn.out_binders, out_types):
if not out_type == out_binder.aval: raise TypeError
for out_binder in eqn.out_binders:
if out_binder in env: raise TypeError
env.add(out_binder)
in_types = [v.aval for v in jaxpr.in_binders]
out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
return JaxprType(in_types, out_types)
def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
if isinstance(x, Var):
if x not in env: raise TypeError("unbound variable")
return x.aval
elif isinstance(x, Lit):
return raise_to_shaped(get_aval(x.val))
else:
assert False
We can apply the function represented by a jaxpr to arguments with a simple interpreter:
def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
env: dict[Var, Any] = {}
def read(x: Atom) -> Any:
return env[x] if type(x) is Var else x.val
def write(v: Var, val: Any) -> None:
assert v not in env # single-assignment
env[v] = val
map(write, jaxpr.in_binders, args)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.inputs)
outs = bind(eqn.primitive, *in_vals, **eqn.params)
map(write, eqn.out_binders, outs)
return map(read, jaxpr.outs)
def jaxpr_as_fun(jaxpr: Jaxpr):
return lambda *args: eval_jaxpr(jaxpr, args)
By using bind in the interpreter, this interpreter itself is traceable.
Building jaxprs with tracing#
Now that we have jaxprs as a data structure, we need ways to produce these from tracing Python code. In general there are two variants of how we trace to a jaxpr; jit uses one and vjp uses the other. We'll start with the one used by jit, which is also used by control flow primitives like lax.cond, lax.while_loop, and lax.scan.
def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
assert 0 <= n <= len(lst)
return lst[:n], lst[n:]
def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
assert len(bs) == len(l)
lists = lst1, lst2 = [], []
for b, x in zip(bs, l):
lists[b].append(x)
return lst1, lst2
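The trick in partition_list, where a boolean indexes into a pair of lists, is compact but easy to misread; here is a standalone demo of its behavior:

```python
# Standalone demo of partition_list: each boolean indexes into a pair of
# lists, so False entries land in the first list and True entries in the
# second.
def partition_list(bs, l):
    assert len(bs) == len(l)
    lists = lst1, lst2 = [], []
    for b, x in zip(bs, l):
        lists[b].append(x)   # False -> lists[0], True -> lists[1]
    return lst1, lst2

falsy, truthy = partition_list([True, False, True], ['a', 'b', 'c'])
print(falsy, truthy)  # ['b'] ['a', 'c']
```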
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
__slots__ = ['aval']
aval: ShapedArray
def __init__(self, trace, aval):
self._trace = trace
self.aval = aval
# NB: the analogous class in JAX is called 'DynamicJaxprTrace'
class JaxprTrace(Trace):
def new_arg(self, aval: ShapedArray) -> JaxprTracer:
aval = raise_to_shaped(aval)
tracer = self.builder.new_tracer(self, aval)
self.builder.tracer_to_var[id(tracer)] = Var(aval)
return tracer
def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
tracer = self.builder.const_tracers.get(id(val))
if tracer is None:
tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
self.builder.add_const(tracer, val)
return tracer
pure = lift = get_or_make_const_tracer
def process_primitive(self, primitive, tracers, params):
avals_in = [t.aval for t in tracers]
avals_out = abstract_eval_rules[primitive](*avals_in, **params)
out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
inputs = [self.builder.getvar(t) for t in tracers]
outvars = [self.builder.add_var(t) for t in out_tracers]
self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
return out_tracers
@property
def builder(self):
return self.main.global_data
# NB: in JAX, we instead attach abstract eval rules to Primitive instances
abstract_eval_rules = {}
Notice that we keep as interpreter-global data a builder object, which keeps track of variables, constants, and equations as we build up the jaxpr.
class JaxprBuilder:
eqns: list[JaxprEqn]
tracer_to_var: dict[int, Var]
const_tracers: dict[int, JaxprTracer]
constvals: dict[Var, Any]
tracers: list[JaxprTracer]
def __init__(self):
self.eqns = []
self.tracer_to_var = {}
self.const_tracers = {}
self.constvals = {}
self.tracers = []
def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
tracer = JaxprTracer(trace, aval)
self.tracers.append(tracer)
return tracer
def add_eqn(self, eqn: JaxprEqn) -> None:
self.eqns.append(eqn)
def add_var(self, tracer: JaxprTracer) -> Var:
assert id(tracer) not in self.tracer_to_var
var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
return var
def getvar(self, tracer: JaxprTracer) -> Var:
var = self.tracer_to_var.get(id(tracer))
assert var is not None
return var
def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
var = self.add_var(tracer)
self.const_tracers[id(val)] = tracer
self.constvals[var] = val
return var
def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
) -> tuple[Jaxpr, list[Any]]:
constvars, constvals = unzip2(self.constvals.items())
t2v = lambda t: self.tracer_to_var[id(t)]
in_binders = constvars + [t2v(t) for t in in_tracers]
out_vars = [t2v(t) for t in out_tracers]
jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
typecheck_jaxpr(jaxpr)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
new_const_binders, lit_binders = partition_list(scalars, const_binders)
new_consts, lit_vals = partition_list(scalars, consts)
literals = dict(zip(lit_binders, map(Lit, lit_vals)))
new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
new_outs = [literals.get(x, x) for x in jaxpr.outs]
new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
typecheck_jaxpr(new_jaxpr)
return new_jaxpr, new_consts
The rules we need for JaxprTrace.process_primitive are essentially typing rules for primitive applications: given the primitive, its parameters, and types for the inputs, the rule must produce a type for the output, which is then packaged with the output JaxprTracer. We can use abstract evaluation rules for this same purpose, even though they can be more general (since abstract evaluation rules must accept ConcreteArray inputs, and since they need only return an upper bound on the set of possible outputs, they can produce ConcreteArray outputs as well). We'll reuse these abstract evaluation rules for the other jaxpr-producing trace machinery, where the potential extra generality is useful.
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval
def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if x.shape != y.shape: raise TypeError
return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval
def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
) -> list[ShapedArray]:
axis_ = set(axis)
new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
axes: Sequence[int]) -> list[ShapedArray]:
return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
To check our implementation of jaxprs, we can add a make_jaxpr transformation and a pretty-printer:
from functools import lru_cache
@lru_cache # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree)
builder = JaxprBuilder()
with new_main(JaxprTrace, builder) as main:
trace = JaxprTrace(main)
tracers_in = [trace.new_arg(aval) for aval in avals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
jaxpr, consts = builder.build(tracers_in, tracers_out)
return jaxpr, consts, out_tree()
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
{ lambda a:float64[] .
let b:float64[] = mul 2.0 a
in ( b ) }
(float64[]) -> (float64[])
But there's a limitation here: because of how find_top_trace operates by data dependence, make_jaxpr_v1 can't stage out all the primitive operations performed by the Python callable it's given. For example:
jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let
in ( 4.0 ) }
This is precisely the issue that omnistaging fixed. We want to ensure that the JaxprTrace started by make_jaxpr is always applied, regardless of whether any inputs to bind are boxed in corresponding JaxprTracer instances. We can achieve this by employing the dynamic_trace global defined in Part 1:
@contextmanager
def new_dynamic(main: MainTrace):
global dynamic_trace
prev_dynamic_trace, dynamic_trace = dynamic_trace, main
try:
yield
finally:
dynamic_trace = prev_dynamic_trace
@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
) -> tuple[Jaxpr, list[Any], PyTreeDef]:
avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree)
builder = JaxprBuilder()
with new_main(JaxprTrace, builder) as main:
with new_dynamic(main):
trace = JaxprTrace(main)
tracers_in = [trace.new_arg(aval) for aval in avals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
jaxpr, consts = builder.build(tracers_in, tracers_out)
return jaxpr, consts, out_tree()
jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let a:float64[] = mul 2.0 2.0
in ( a ) }
Using dynamic_trace this way is conceptually the same as stashing the current interpreter stack and starting a new one with the JaxprTrace at the bottom. That is, no interpreters lower in the stack than the dynamic_trace are applied (since JaxprTrace.process_primitive doesn't call bind), though if the Python callable being traced to a jaxpr itself uses transformations then those can be pushed onto the interpreter stack above the JaxprTrace. But temporarily stashing the interpreter stack would break up the system state. The dynamic_trace tag achieves the same goals while keeping the system state simpler.
That's it for jaxprs! With jaxprs in hand, we can implement the remaining major JAX features.
Part 3: jit, simplified#
While jit has a transformation-like API in that it accepts a Python callable as an argument, under the hood it's really a higher-order primitive rather than a transformation. A primitive is *higher-order* when it's parameterized by a function.
On-the-fly ("final style") and staged ("initial style") processing#
There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs:
Option 1: on-the-fly processing, where bind takes a Python callable as an argument. We defer forming a jaxpr until as late as possible, namely until we're running the final interpreter at the bottom of the interpreter stack. That way we can swap a JaxprTrace in at the bottom of the stack and thus stage out rather than execute all primitive operations. With this approach, transformations in the stack get applied as we execute the Python callable as usual. This approach can be very tricky to implement, but it's as general as possible because it allows higher-order primitives not to raise the abstraction level of their arguments and thus allows data-dependent Python control flow. We refer to this approach as using a "final-style higher-order primitive", employing the discharge-at-tracing-time "final-style transformations" we've used so far.

Option 2: staged processing, where bind takes a jaxpr as an argument. Before we call bind, in the primitive wrapper we can just use make_jaxpr to form a jaxpr up-front and be done with the Python callable entirely. In this case, make_jaxpr puts its JaxprTrace at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive's rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can't support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an "initial-style higher-order primitive", and say that its jaxpr-processing transformation rules are "initial-style transformation rules".
The latter approach fits for jit because we don't need to support data-dependent Python control flow in the user-provided Python callable, as the whole point of jit is to stage computation out of Python to be executed by XLA. (In contrast, custom_jvp is a higher-order primitive in which we do want to support data-dependent Python control flow.)
Historically, we started using the "initial-style" and "final-style" terminology after reading the typed tagless final interpreters paper, and jokingly referring to JAX as an implementation of "untyped tagful final interpreters." We don't claim to carry over (or even understand) any deep meaning behind these terms; we loosely use "initial style" to mean "build an AST and then transform it", and "final style" to mean "transform as we trace." But it's just imprecise yet sticky jargon.
With the initial-style approach, here's the user-facing jit wrapper:
def jit(f):
def f_jitted(*args):
avals_in = [raise_to_shaped(get_aval(x)) for x in args]
jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
return tree_unflatten(out_tree, outs)
return f_jitted
xla_call_p = Primitive('xla_call')
With any new primitive, we need to give it transformation rules, starting with its evaluation rule. When we evaluate an application of the xla_call primitive, we want to stage out the computation to XLA. That involves translating the jaxpr to an XLA HLO program, transferring the argument values to the XLA device, executing the XLA program, and transferring back the results. We'll cache the XLA HLO compilation so that for each jitted function it only needs to be performed once per argument shape and dtype signature.
First, some utilities:
class IDHashable:
val: Any
def __init__(self, val):
self.val = val
def __hash__(self) -> int:
return id(self.val)
def __eq__(self, other):
return type(other) is IDHashable and id(self.val) == id(other.val)
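As a standalone illustration of why this helper exists: lru_cache needs hashable keys that compare consistently, and IDHashable compares by object identity, so otherwise-unhashable values like lists of constants can serve as cache keys:

```python
# Standalone demo of the IDHashable pattern: equal-by-value objects hash and
# compare differently unless they are the very same object, which is what
# lets mutable values (like lists of consts) act as cache keys by identity.
class IDHashable:
    def __init__(self, val):
        self.val = val
    def __hash__(self) -> int:
        return id(self.val)
    def __eq__(self, other):
        return type(other) is IDHashable and id(self.val) == id(other.val)

xs = [1, 2, 3]
same = {IDHashable(xs): 'cached'}
print(IDHashable(xs) in same)         # True: same underlying object
print(IDHashable([1, 2, 3]) in same)  # False: equal value, different identity
```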
Next, we'll define the evaluation rule for xla_call:
import io
from jax.extend.mlir import ir
from jax.extend.mlir.dialects import func
from jax.extend.mlir.dialects import stablehlo as hlo
from jax._src import xla_bridge as xb
class MlirContext(NamedTuple):
module: ir.Module
symbol_table: ir.SymbolTable
def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
consts, args = args[:num_consts], args[num_consts:]
hashable_consts = tuple(map(IDHashable, consts))
execute = xla_callable(IDHashable(jaxpr), hashable_consts)
return execute(*args)
impl_rules[xla_call_p] = xla_call_impl
@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
hashable_consts: tuple[IDHashable, ...]):
jaxpr: Jaxpr = hashable_jaxpr.val
typecheck_jaxpr(jaxpr)
consts = [x.val for x in hashable_consts]
in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
with ir.Context() as ctx, ir.Location.unknown(ctx):
hlo.register_dialect(ctx)
m = ir.Module.create()
c = MlirContext(m, ir.SymbolTable(m.operation))
with ir.InsertionPoint(c.module.body):
@func.func(*(aval_to_ir_type(aval) for aval in in_avals))
def main(*params):
return jaxpr_subcomp(c, jaxpr, _hlo_consts(consts) + params)
output = io.StringIO()
c.module.operation.print(file=output)
backend = xb.get_backend(None)
compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1])
return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])
def _mlir_dtype(dtype: np.dtype) -> ir.Type:
if np.issubdtype(dtype, np.signedinteger):
return ir.IntegerType.get_signless(np.iinfo(dtype).bits)
elif dtype == np.float32:
return ir.F32Type.get()
elif dtype == np.float64:
return ir.F64Type.get()
else:
raise NotImplementedError(f"MLIR conversion not implemented for {dtype}")
def aval_to_ir_type(aval: ShapedArray) -> ir.Type:
return ir.RankedTensorType.get(aval.shape, _mlir_dtype(aval.dtype))
def _hlo_const(x: Any) -> ir.Value:
a = np.asarray(x)
if a.dtype == np.bool_:
return hlo.constant(ir.DenseElementsAttr.get(
np.packbits(a, bitorder='little'), type=ir.IntegerType.get_signless(1),
shape=a.shape))
else:
return hlo.constant(ir.DenseElementsAttr.get(a))
def _hlo_consts(consts: list[Any]) -> list[ir.Value]:
unique_consts = {id(cnst): cnst for cnst in consts}
ir_consts = {id_: _hlo_const(cnst) for id_, cnst in unique_consts.items()}
return tuple(ir_consts[id(cnst)] for cnst in consts)
主要工作由 xla_callable 完成:它用 jaxpr_subcomp 把 jaxpr 编译成 XLA HLO 程序,然后返回一个用于执行已编译程序的可调用对象。
def jaxpr_subcomp(c: MlirContext, jaxpr: Jaxpr, args: list[ir.Value]) -> list[ir.Value]:
env: dict[Var, ir.Value] = {}
def read(x: Atom) -> ir.Value:
return env[x] if type(x) is Var else _hlo_const(np.asarray(x.val))
def write(v: Var, val: ir.Value) -> None:
env[v] = val
map(write, jaxpr.in_binders, args)
for eqn in jaxpr.eqns:
in_avals = [x.aval for x in eqn.inputs]
in_vals = map(read, eqn.inputs)
out_avals = [x.aval for x in eqn.out_binders]
rule = hlo_translations[eqn.primitive]
assert all(isinstance(v, ir.Value) for v in in_vals), in_vals
out_vals = rule(c, in_avals, out_avals, in_vals, **eqn.params)
assert all(isinstance(v, ir.Value) for v in out_vals), out_vals
map(write, eqn.out_binders, out_vals)
return map(read, jaxpr.outs)
def execute_compiled(compiled, out_avals, *args):
input_bufs = [input_handlers[type(x)](x) for x in args]
out_bufs = compiled.execute(input_bufs)
return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]
default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
[bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now
return np.asarray(buf)
hlo_translations = {}
请注意,jaxpr_subcomp 具有简单解释器的结构。这是一种常见模式:我们处理 jaxprs 的方式通常是通过解释器。并且与任何解释器一样,我们需要为每个原语提供解释规则。
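这个“建立环境、逐条解释等式、读取输出”的模式可以脱离正文代码独立示意。下面是一个极简的等式列表解释器(数据结构为本例假设,并非正文的 Jaxpr 类):每条等式是 (输出变量名, 原语名, 输入变量名列表),解释流程与 jaxpr_subcomp 一致。

```python
import numpy as np

# 每个原语对应一条解释规则
rules = {'sin': np.sin, 'cos': np.cos,
         'mul': lambda x, y: x * y, 'neg': lambda x: -x}

def eval_eqns(in_binders, eqns, outs, args):
    env = dict(zip(in_binders, args))          # write:绑定输入
    for out_var, prim, in_vars in eqns:
        in_vals = [env[v] for v in in_vars]    # read:取输入值
        env[out_var] = rules[prim](*in_vals)   # 应用该原语的规则并 write
    return [env[v] for v in outs]              # read:取输出

# 表示 lambda a: -(sin(a) * 2.) 的等式列表
eqns = [('b', 'sin', ['a']),
        ('c', 'mul', ['b', 'two']),
        ('d', 'neg', ['c'])]
out, = eval_eqns(['a', 'two'], eqns, ['d'], [3.0, 2.0])
assert np.isclose(out, -np.sin(3.0) * 2.0)
```

换掉 rules 里的规则(比如换成发射 HLO 指令的函数),同一个解释器骨架就变成了一个降低器,这正是 jaxpr_subcomp 的做法。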
def direct_translation(op, c, in_avals, out_avals, in_vals):
del c, in_avals, out_avals
return [op(*in_vals)]
hlo_translations[add_p] = partial(direct_translation, hlo.add)
hlo_translations[mul_p] = partial(direct_translation, hlo.multiply)
hlo_translations[neg_p] = partial(direct_translation, hlo.negate)
hlo_translations[sin_p] = partial(direct_translation, hlo.sine)
hlo_translations[cos_p] = partial(direct_translation, hlo.cosine)
def compare_translation(op, c, in_avals, out_avals, in_vals):
del c, out_avals
return [hlo.compare(*in_vals, hlo.ComparisonDirectionAttr.get(op))]
hlo_translations[greater_p] = partial(compare_translation, "GT")
hlo_translations[less_p] = partial(compare_translation, "LT")
def reduce_sum_translation(c, in_avals, out_avals, in_vals, *, axis):
del c
(x_aval,), (out_aval,), (x,) = in_avals, out_avals, in_vals
op = hlo.ReduceOp(
[aval_to_ir_type(out_aval)], [x], [_hlo_const(np.array(0, x_aval.dtype))],
axis)
scalar_type = aval_to_ir_type(ShapedArray((), x_aval.dtype))
reducer_region = op.body.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_region):
hlo.return_([hlo.add(*reducer_region.arguments)])
return op.results
hlo_translations[reduce_sum_p] = reduce_sum_translation
def broadcast_translation(c, in_avals, out_avals, in_vals, *, shape, axes):
del c
(x,), (out_aval,) = in_vals, out_avals
dims_complement = [i for i in range(len(shape)) if i not in axes]
return [hlo.broadcast_in_dim(aval_to_ir_type(out_aval), x, dims_complement)]
hlo_translations[broadcast_p] = broadcast_translation
这样,我们现在就可以使用 jit 来分阶段化、编译和使用 XLA 执行程序了!
@jit
def f(x, y):
print('tracing!')
return sin(x) * cos(y)
z = f(3., 4.) # 'tracing!' prints the first time
print(z)
tracing!
-0.09224219304455371
z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit!
print(z)
-0.21467624978306993
@jit
def f(x):
return reduce_sum(x, axis=0)
print(f(np.array([1., 2., 3.])))
6.0
def f(x):
y = sin(x) * 2.
z = - y + x
return z
def deriv(f):
return lambda x: jvp(f, (x,), (1.,))[1]
print( deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
0.2822400161197344
0.2822400161197344
看起来我们本可以不这样实现 jit(先追踪到 jaxpr,再把 jaxpr 降低到 XLA HLO),而是跳过 jaxpr 这一步,在追踪的同时直接降低到 HLO。也就是说,也许我们可以用一个 Trace 和 Tracer 来实现 jit,在每次绑定基本操作时增量地向 XLA HLO 图追加内容。就目前而言这是对的,但当我们引入编译的 SPMD 计算时就行不通了,因为那时必须在编译程序之前知道所需的副本数量。
除了求值规则之外,我们还没有为 xla_call_p 定义任何转换规则。也就是说,我们还不能做 vmap-of-jit 或 jvp-of-jit,甚至不能做 jit-of-jit;目前 jit 只能出现在“顶层”。让我们来解决这个问题!
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
del num_consts # Unused
new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
num_consts=len(new_consts))
n = len(outs) // 2
primals_out, tangents_out = outs[:n], outs[n:]
return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2
primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
return jvp(jaxpr_as_fun(jaxpr), primals, tangents)
in_avals = [v.aval for v in jaxpr.in_binders]
new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
return new_jaxpr, new_consts
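jvp_jaxpr 里的 jvp_traceable 采用了一个“打包调用约定”:primal 和 tangent 被拼进同一个扁平参数表,内部再对半拆开,输出同样前半是 primal、后半是 tangent。这个约定可以用普通 Python 独立示意(本例与正文代码无关;toy_jvp 是本示例假设的一个极简前向模式求导,仅支持 * 和 +):

```python
class Dual:
    def __init__(self, primal, tangent):
        self.primal, self.tangent = primal, tangent
    def _lift(self, other):
        return other if isinstance(other, Dual) else Dual(other, 0.0)
    def __mul__(self, other):
        o = self._lift(other)
        return Dual(self.primal * o.primal,
                    self.primal * o.tangent + self.tangent * o.primal)
    __rmul__ = __mul__
    def __add__(self, other):
        o = self._lift(other)
        return Dual(self.primal + o.primal, self.tangent + o.tangent)
    __radd__ = __add__

def toy_jvp(f, primals, tangents):
    outs = f(*[Dual(p, t) for p, t in zip(primals, tangents)])
    return [o.primal for o in outs], [o.tangent for o in outs]

def jvp_traceable(f, *primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    primals_out, tangents_out = toy_jvp(f, primals, tangents)
    return [*primals_out, *tangents_out]  # 扁平输出:前半 primal,后半 tangent

f = lambda x, y: [x * y + x]
outs = jvp_traceable(f, 2.0, 3.0, 1.0, 0.0)  # primals=(2,3),tangents=(1,0)
n = len(outs) // 2
assert outs[:n] == [8.0] and outs[n:] == [4.0]  # d/dx(x*y + x) = y + 1 = 4
```

正文的 JVP 规则正是靠这种扁平化约定,把“对 f 的一次调用”替换成“对 JVP 化的 f 的一次调用”,输入输出各翻一倍。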
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
del num_consts # Unused
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
num_consts=len(new_consts))
return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
) -> tuple[Jaxpr, list[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
in_avals = [unmapped_aval(axis_size, d, v.aval)
for v, d in zip(jaxpr.in_binders, bdims_in)]
new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
return new_jaxpr, new_consts
def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
) -> ShapedArray:
if batch_dim is not_mapped:
return aval
else:
shape = list(aval.shape)
shape.insert(batch_dim, axis_size)
return ShapedArray(tuple(shape), aval.dtype)
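为直观起见,可以把 unmapped_aval 的形状逻辑与一个简化的 ShapedArray 一起独立运行。下面的 namedtuple 版本和 not_mapped = None 哨兵都是本例的简化假设(正文里 not_mapped 是一个专门的哨兵对象):

```python
from collections import namedtuple

ShapedArray = namedtuple('ShapedArray', ['shape', 'dtype'])
not_mapped = None  # 简化:用 None 充当“未映射”哨兵

def unmapped_aval(axis_size, batch_dim, aval):
    if batch_dim is not_mapped:
        return aval
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)  # 在被映射的位置插回批处理维
    return ShapedArray(tuple(shape), aval.dtype)

a = ShapedArray((4, 5), 'float32')
assert unmapped_aval(3, 0, a).shape == (3, 4, 5)  # 批处理维在最前
assert unmapped_aval(3, 1, a).shape == (4, 3, 5)  # 批处理维在中间
assert unmapped_aval(3, not_mapped, a) is a       # 未映射则原样返回
```

vmap_jaxpr 需要这个“反映射”:被 vmap 的 jaxpr 的每个输入 aval 都要把批处理维加回去,才能得到批处理后函数的输入类型。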
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
del num_consts # Unused
jaxpr_type = typecheck_jaxpr(jaxpr)
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts):
del num_consts, out_avals
# Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
with ir.InsertionPoint(c.module.body):
@func.func(*(aval_to_ir_type(aval) for aval in in_avals))
def inner_xla_call(*params):
return jaxpr_subcomp(c, jaxpr, params)
name = c.symbol_table.insert(inner_xla_call.func_op)
return func.CallOp(inner_xla_call.func_op, in_vals).results
hlo_translations[xla_call_p] = xla_call_translation
@jit
def f(x):
print('tracing!')
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
tracing!
2.7177599838802657
2.979984993200891
y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
[ 0. -0.68294197 0.18140515]
还缺少一块:数组在设备内存中的持久化。也就是说,我们定义的 handle_result 会把结果作为 NumPy 数组传回 CPU 内存,但通常更好的做法是避免仅仅为了在下一个操作中使用而来回传输结果。我们可以通过引入一个 Array 类来做到这一点:它包装 XLA 缓冲区,并在其他方面模拟 numpy.ndarray。
def handle_result(aval: ShapedArray, buf): # noqa: F811
return Array(aval, buf)
class Array:
buf: Any
aval: ShapedArray
def __init__(self, aval, buf):
self.aval = aval
self.buf = buf
dtype = property(lambda self: self.aval.dtype)
shape = property(lambda self: self.aval.shape)
ndim = property(lambda self: self.aval.ndim)
def __array__(self): return np.asarray(self.buf)
def __repr__(self): return repr(np.asarray(self.buf))
def __str__(self): return str(np.asarray(self.buf))
_neg = staticmethod(neg)
_add = staticmethod(add)
_radd = staticmethod(add)
_mul = staticmethod(mul)
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf
jax_types.add(Array)
@jit
def f(x):
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
第四部分:linearize 和 vjp(以及 grad!)#
linearize 和 vjp 这两个自动微分函数建立在 jvp 之上,但也会用到 jaxpr。这是因为二者都涉及将计算分阶段化,也就是延迟计算。
linearize#
在 linearize 的情况下,我们希望分阶段化 jvp 计算的线性部分。也就是说,用类似 Haskell 的类型签名来说,如果我们有 jvp : (a -> b) -> (a, T a) -> (b, T b),那么我们写 linearize : (a -> b) -> a -> (b, T a -o T b),使用 T a 表示“a 的切线类型”,并使用“棒棒糖” -o 而不是箭头 -> 来表示*线性*函数。我们也用 jvp 来定义 linearize 的语义。
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)
在 (y, y_dot) 上给出与下面的写法相同的结果:
y, y_dot = jvp(f, (x,), (x_dot,))
其中 f_lin 的应用不会重新执行任何线性化工作。我们将延迟的线性部分 f_lin : T a -o T b 表示为一个 jaxpr。
题外话,既然我们有了线性箭头 -o,我们可以为 jvp 提供一个稍微信息量更丰富的类型。
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
这里我们写 UnrestrictedUse 只是为了表明我们有一个特殊的对,其中第一个元素可以以无限制(非线性)方式使用。结合线性箭头,这种表示法只是用来表示函数 jvp f 以非线性的方式使用其第一个输入,但以线性的方式使用其第二个输入,产生一个相应的非线性输出(可以以非线性的方式使用)和一个线性输出。这种更精细的类型签名编码了 jvp f 中的数据依赖关系,这对于部分求值很有用。
为了从 JVP 构建 f_lin 的 jaxpr,我们需要进行部分求值:我们在追踪时求值所有原始值,但把切线计算分阶段到一个 jaxpr 中。这是我们构建 jaxpr 的第二种方法。
首先,一些实用程序。
def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
assert not len(lst) % 2
return split_list(lst, len(lst) // 2)
def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
l1, l2 = iter(l1), iter(l2)
out = [next(l2) if b else next(l1) for b in which]
assert next(l1, None) is next(l2, None) is None
return out
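split_half 和 merge_lists 的行为可以独立验证。下面是与正文等价的独立副本(外加它依赖的 split_list),并演示 merge_lists 按布尔掩码把两个子列表归并回原位置的方式:

```python
def split_list(lst, n):
    return lst[:n], lst[n:]

def split_half(lst):
    assert not len(lst) % 2
    return split_list(lst, len(lst) // 2)

def merge_lists(which, l1, l2):
    l1, l2 = iter(l1), iter(l2)
    # which 中 False 取自 l1,True 取自 l2,按原位置交错归并
    out = [next(l2) if b else next(l1) for b in which]
    assert next(l1, None) is next(l2, None) is None  # 两个列表都应恰好耗尽
    return out

assert split_half([1, 2, 3, 4]) == ([1, 2], [3, 4])
assert merge_lists([False, True, True, False],
                   ['a', 'b'], ['x', 'y']) == ['a', 'x', 'y', 'b']
```

在部分求值中,这正是把“已知输出”和“未知输出”按原顺序拼回完整输出列表的工具。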
接下来,我们将通过结合 jvp 和一个通用的部分求值转换(将在下一步添加)来编写 linearize。
def linearize_flat(f, *primals_in):
pvals_in = ([PartialVal.known(x) for x in primals_in] +
[PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
def f_jvp(*primals_tangents_in):
primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
return [*primals_out, *tangents_out]
jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
primal_pvals, _ = split_half(pvals_out)
assert all(pval.is_known for pval in primal_pvals)
primals_out = [pval.const for pval in primal_pvals]
f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
return primals_out, f_lin
def linearize(f, *primals_in):
primals_in_flat, in_tree = tree_flatten(primals_in)
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
def f_lin(*tangents_in):
tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
if in_tree != in_tree2: raise TypeError
tangents_out_flat = f_lin_flat(*tangents_in_flat)
return tree_unflatten(out_tree(), tangents_out_flat)
return primals_out, f_lin
def vspace(aval: ShapedArray) -> ShapedArray:
return raise_to_shaped(aval) # TODO handle integers?
现在我们转向通用的部分求值转换。目标是接受一个 Python 可调用对象和一系列输入,其中一些是已知的,一些是未知的,并生成(1)可以从已知输入计算的所有输出,以及(2)表示 Python 可调用对象计算中必须延迟的部分的 jaxpr,直到其余输入已知。
这种转换很难用类型签名来总结。如果我们假设输入函数的类型签名是 (a1, a2) -> (b1, b2),其中 a1 和 a2 分别代表已知和未知输入,并且其中 b1 仅在数据上依赖于 a1,而 b2 在数据上依赖于 a2,那么我们可能会写:
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
用语言来说,给定类型为 a1 的输入值,partial_eval 会产生类型为 b1 的输出,以及存在量化类型 r 的“残差”值,它代表完成第二阶段计算所需的中间结果。它还产生一个类型为 (r, a2) -> b2 的函数,该函数接受残差值和其余输入,并产生其余输出。
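这个类型签名可以用一个手写的具体例子体会(toy_partial_eval 为本例假设,是对下面这个特定 f 的手工部分求值,而非正文的通用转换):b1 只依赖已知输入 a1,可以立即算出;b2 依赖未知输入 a2,于是只保留残差 r 和一个剩余计算。

```python
import numpy as np

def f(a1, a2):
    b1 = np.sin(a1) * 2.   # 只依赖已知输入 a1
    b2 = np.cos(a1) * a2   # 依赖未知输入 a2(以及来自 a1 的中间值)
    return b1, b2

def toy_partial_eval(a1):
    b1 = np.sin(a1) * 2.   # 第一阶段:立即求值
    r = np.cos(a1)         # 残差:第二阶段需要的中间值(类型 r)
    def rest(r, a2):       # 剩余计算,类型 (r, a2) -> b2
        return r * a2
    return b1, r, rest

b1, r, rest = toy_partial_eval(3.)
b1_ref, b2_ref = f(3., 5.)
assert b1 == b1_ref and rest(r, 5.) == b2_ref  # 两阶段合起来与 f 等价
```

正文的通用转换做的正是同一件事,只是残差和剩余计算是在追踪过程中自动发现并以 jaxpr 形式存下来的。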
我们喜欢将部分求值视为将一个计算“解压缩”为两个。例如,考虑这个 jaxpr:
{ lambda a:float64[] .
let b:float64[] = sin a
c:float64[] = neg b
in ( c ) }
JVP 的 jaxpr 将是:
{ lambda a:float64[] b:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
e:float64[] = mul d b
f:float64[] = neg c
g:float64[] = neg e
in ( f, g ) }
如果我们想象用第一个输入已知,第二个输入未知来应用此 jaxpr 的部分求值,我们将把 JVP jaxpr“解压缩”为原始和切线 jaxprs:
{ lambda a:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
f:float64[] = neg c
in ( f, d ) }
{ lambda d:float64[] b:float64[] .
let e:float64[] = mul d b
g:float64[] = neg e
in ( g ) }
第二个 jaxpr 代表了我们从 linearize 中获得的线性计算。
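上面“解压缩”的结果可以直接转写成 NumPy 函数来数值验证:原始 jaxpr 额外输出残差 d,线性 jaxpr 消费这个残差,两个阶段合起来与完整的 JVP jaxpr 等价。

```python
import numpy as np

def jvp_jaxpr(a, b):          # 完整的 JVP jaxpr
    c = np.sin(a)
    d = np.cos(a)
    e = d * b
    return -c, -e

def primal_jaxpr(a):          # 第一个(原始)jaxpr:输出 f 和残差 d
    c = np.sin(a)
    d = np.cos(a)
    return -c, d

def tangent_jaxpr(d, b):      # 第二个(线性)jaxpr:消费残差 d
    e = d * b
    return -e

a, b = 3.0, 1.0
f1, g1 = jvp_jaxpr(a, b)
f2, d = primal_jaxpr(a)
g2 = tangent_jaxpr(d, b)
assert f1 == f2 and g1 == g2  # 两阶段执行与一次性执行结果一致
```

注意残差 d = cos(a) 正是两个阶段之间唯一需要传递的值,这决定了第一个 jaxpr 的额外输出和第二个 jaxpr 的额外输入。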
然而,与这个 jaxpr 示例不同,我们希望在求值输入 Python 可调用对象时发生已知值的计算。也就是说,而不是形成整个函数 (a1, a2) -> (b1, b2) 的 jaxpr,先将所有操作从 Python 分阶段出来,然后再区分哪些可以现在求值,哪些必须延迟,我们只想形成那些*必须*由于依赖于未知输入而延迟的操作的 jaxpr。在自动微分的上下文中,这是最终能够处理像 grad(lambda x: x**2 if x > 0 else 0.) 这样的函数的功能。Python 控制流可以工作,因为部分求值将原始计算保留在 Python 中。因此,我们的 Trace 和 Tracer 子类必须即时区分什么是可以求值的,什么必须分阶段到一个 jaxpr 中。
首先,我们从一个 PartialVal 类开始,它表示一个可以是已知或未知的值。
class PartialVal(NamedTuple):
aval: ShapedArray
const: Any | None
@classmethod
def known(cls, val: Any):
return PartialVal(get_aval(val), val)
@classmethod
def unknown(cls, aval: ShapedArray):
return PartialVal(aval, None)
is_known = property(lambda self: self.const is not None)
is_unknown = property(lambda self: self.const is None)
部分求值将接受一个表示输入的 PartialVals 列表,并返回一个 PartialVal 输出列表以及一个表示延迟计算的 jaxpr。
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
with new_main(PartialEvalTrace) as main:
trace = PartialEvalTrace(main)
tracers_in = [trace.new_arg(pval) for pval in pvals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
pvals_out = [t.pval for t in tracers_out]
unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown]
unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
return jaxpr, pvals_out, consts
接下来我们需要实现 PartialEvalTrace 及其 PartialEvalTracer。这个解释器将在跟踪数据依赖的同时,即时构建一个 jaxpr。为此,它构建了一个双向有向无环图(DAG),在表示分阶段值的 PartialEvalTracer 节点和表示如何从其他值计算某些值的公式的 JaxprRecipe 节点之间。一种配方是 JaxprEqnRecipe,对应于 JaxprEqn 的基本操作应用,但我们也有配方类型用于常量和 lambda 绑定。
from weakref import ref, ReferenceType
class LambdaBindingRecipe(NamedTuple):
pass
class ConstRecipe(NamedTuple):
val: Any
class JaxprEqnRecipe(NamedTuple):
prim: Primitive
tracers_in: list['PartialEvalTracer']
params: dict[str, Any]
avals_out: list[ShapedArray]
tracer_refs_out: list['ReferenceType[PartialEvalTracer]']
JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
pval: PartialVal
recipe: JaxprRecipe | None
def __init__(self, trace, pval, recipe):
self._trace = trace
self.pval = pval
self.recipe = recipe
aval = property(lambda self: self.pval.aval)
def full_lower(self):
if self.pval.is_known:
return full_lower(self.pval.const)
return self
PartialEvalTrace 包含用于构建 JaxprRecipes 和 PartialEvalTracers 图的逻辑。每个参数对应一个 LambdaBindingRecipe 叶节点,每个常量是一个 ConstRecipe 叶节点,它保存对常量的引用。所有其他追踪器和配方都来自 process_primitive,它使用 JaxprEqnRecipes 来形成追踪器。
对于大多数基本操作,process_primitive 的逻辑很简单:如果所有输入都已知,我们就可以在已知值上绑定该基本操作(在 Python 中求值),从而无需为输出构造追踪器。反之,如果任何输入是未知的,我们就把它分阶段到一个表示该基本操作应用的 JaxprEqnRecipe 中。为了构造表示未知输出的追踪器,我们需要 avals,它们来自抽象求值规则。(注意,追踪器引用 JaxprEqnRecipe,而 JaxprEqnRecipe 又引用追踪器;我们使用弱引用(weakref)来避免产生循环引用垃圾。)
该 process_primitive 逻辑适用于大多数基本操作,但 xla_call_p 需要递归处理。因此,我们在 partial_eval_rules 字典中特殊处理其规则。
class PartialEvalTrace(Trace):
def new_arg(self, pval: PartialVal) -> Any:
return PartialEvalTracer(self, pval, LambdaBindingRecipe())
def lift(self, val: Any) -> PartialEvalTracer:
return PartialEvalTracer(self, PartialVal.known(val), None)
pure = lift
def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
if tracer.pval.is_unknown:
return tracer
else:
pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))
def process_primitive(self, primitive, tracers, params):
if all(t.pval.is_known for t in tracers):
return bind(primitive, *map(full_lower, tracers), **params)
rule = partial_eval_rules.get(primitive)
if rule: return rule(self, tracers, **params)
tracers_in = [self.instantiate_const(t) for t in tracers]
avals_in = [t.aval for t in tracers_in]
avals_out = abstract_eval_rules[primitive](*avals_in, **params)
tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
for aval in avals_out]
eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
map(ref, tracers_out))
for t in tracers_out: t.recipe = eqn
return tracers_out
partial_eval_rules = {}
现在我们有了使用 PartialEvalTrace 构建 jaxprs 的图表示,我们需要一种机制将图表示转换为标准 jaxpr。jaxpr 对应于图的拓扑排序。
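拓扑排序这一步可以独立示意(下面的 toposort 是本例的简化实现,与正文实际使用的版本无关):给定输出节点和“取父节点”的函数,对 DAG 做后序深度优先遍历即得到一个拓扑序,父节点总是排在子节点之前,这正是把配方图变成 jaxpr 等式列表所需的顺序。

```python
def toposort(out_nodes, parents):
    seen, order = set(), []
    def visit(node):
        if node in seen:
            return
        seen.add(node)
        for p in parents(node):  # 先访问所有父节点
            visit(p)
        order.append(node)       # 再输出当前节点:保证父先于子
    for node in out_nodes:
        visit(node)
    return order

# 菱形依赖:d 依赖 b 和 c,b、c 都依赖 a
graph = {'a': [], 'b': ['a'], 'c': ['a'], 'd': ['b', 'c']}
order = toposort(['d'], lambda n: graph[n])
assert order.index('a') < order.index('b') < order.index('d')
assert order.index('a') < order.index('c') < order.index('d')
```

真实实现中节点是不可哈希的追踪器对象,因此按 id 去重;这里为简明起见直接用可哈希的字符串节点。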
def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
tracers_out: list[PartialEvalTracer]):
tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
for t in tracers_in}
constvar_to_val: dict[Var, Any] = {}
constid_to_var: dict[int, Var] = {}
processed_eqns: set[int] = set()
eqns: list[JaxprEqn] = []
for t in toposort(tracers_out, tracer_parents):
if isinstance(t.recipe, LambdaBindingRecipe):
assert id(t) in set(map(id, tracers_in))
elif isinstance(t.recipe, ConstRecipe):
val = t.recipe.val
var = constid_to_var.get(id(val))
if var is None:
aval = raise_to_shaped(get_aval(val))
var = constid_to_var[id(val)] = Var(aval)
constvar_to_val[var] = val
tracer_to_var[id(t)] = var
elif isinstance(t.recipe, JaxprEqnRecipe):
if id(t.recipe) not in processed_eqns:
eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
processed_eqns.add(id(t.recipe))
else:
raise TypeError(t.recipe)
constvars, constvals = unzip2(constvar_to_val.items())
in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
out_vars = [tracer_to_var[id(t)] for t in tracers_out]
jaxpr = Jaxpr(in_binders, eqns, out_vars)
typecheck_jaxpr(jaxpr)
return jaxpr, constvals
def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
) -> JaxprEqn:
inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
out_binders = [Var(aval) for aval in recipe.avals_out]
for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
if t_ref() is not None: tracer_to_var[id(t_ref())] = var
return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
现在我们可以线性化了!
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454
为了处理 linearize-of-jit,我们仍然需要为 xla_call_p 编写部分求值规则。除了追踪器簿记之外,主要任务是对 jaxpr 进行部分求值,将其“解压缩”成两个 jaxprs。
实际上需要编写两个规则:一个用于追踪时部分求值,我们称之为 xla_call_partial_eval,另一个用于 jaxprs 的部分求值,我们称之为 xla_call_peval_eqn。
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
del num_consts # Unused
in_unknowns = [not t.pval.is_known for t in tracers]
jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
known_vals = [t.pval.const for t in known_tracers]
outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
for v in jaxpr2.outs]
eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
dict(jaxpr=jaxpr2, num_consts=0),
[v.aval for v in jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
instantiate: list[bool] | None = None,
) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
env: dict[Var, bool] = {}
residuals: set[Var] = set()
def read(x: Atom) -> bool:
return type(x) is Var and env[x]
def write(unk: bool, v: Var) -> None:
env[v] = unk
def new_res(x: Atom) -> Atom:
if type(x) is Var: residuals.add(x)
return x
eqns1, eqns2 = [], []
map(write, in_unknowns, jaxpr.in_binders)
for eqn in jaxpr.eqns:
unks_in = map(read, eqn.inputs)
rule = partial_eval_jaxpr_rules.get(eqn.primitive)
if rule:
eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
map(write, unks_out, eqn.out_binders)
elif any(unks_in):
inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
map(partial(write, True), eqn.out_binders)
else:
eqns1.append(eqn)
map(partial(write, False), eqn.out_binders)
out_unknowns = map(read, jaxpr.outs)
if instantiate is not None:
for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
if inst and not uk: new_res(v)
out_unknowns = map(op.or_, out_unknowns, instantiate)
residuals, num_res = list(residuals), len(residuals)
assert all(type(v) is Var for v in residuals), residuals
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)
jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)
return jaxpr1, jaxpr2, out_unknowns, num_res
def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
jaxprty = typecheck_jaxpr(jaxpr) # (a1, a2) -> (b1, b2 )
jaxpr1ty = typecheck_jaxpr(jaxpr1) # a1 -> (b1, res)
jaxpr2ty = typecheck_jaxpr(jaxpr2) # (res, a2) -> b2
a1, a2 = partition_list(unks_in, jaxprty.in_types)
b1, b2 = partition_list(unks_out, jaxprty.out_types)
b1_, res = split_list(jaxpr1ty.out_types, len(b1))
res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
b2_ = jaxpr2ty.out_types
if jaxpr1ty.in_types != a1: raise TypeError
if jaxpr2ty.out_types != b2: raise TypeError
if b1 != b1_: raise TypeError
if res != res_: raise TypeError
if a2 != a2_: raise TypeError
if b2 != b2_: raise TypeError
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
jaxpr = eqn.params['jaxpr']
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
ins1, ins2 = partition_list(unks_in, eqn.inputs)
out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
out_binders1 + residuals)
eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
这样,我们就可以任意组合 linearize 和 jit。
@jit
def f(x):
y = sin(x) * 2.
z = - y + x
return z
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
@jit
def f(x):
y = sin(x) * 2.
z = g(x, y)
return z
@jit
def g(x, y):
return cos(x) + y
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758
vjp 和 grad#
vjp 转换的工作方式与 linearize 非常相似。其类型签名是类似的。
linearize : (a -> b) -> a -> (b, T a -o T b)
vjp : (a -> b) -> a -> (b, T b -o T a)
唯一的区别是,我们会在返回线性部分之前先对它做转置,把它的类型从 T a -o T b 变为 T b -o T a。也就是说,我们本质上把 vjp 实现为:
def vjp(f, x):
y, f_lin = linearize(f, x)
f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
return y, f_vjp
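转置的定义性质是内积恒等式:对线性函数 f_lin 和余切向量 y̅,有 ⟨y̅, f_lin(x)⟩ = ⟨transpose(f_lin)(y̅), x⟩。可以用一个具体的矩阵线性映射做数值验证(矩阵 A 与随机种子均为本例假设):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))

f_lin = lambda x: A @ x        # 类型 T a -o T b 的线性函数
f_lin_T = lambda y: A.T @ y    # 它的转置,类型 T b -o T a

x = rng.normal(size=4)
y_bar = rng.normal(size=3)
# 转置的定义性质:<y_bar, f_lin(x)> == <f_lin_T(y_bar), x>
assert np.isclose(y_bar @ f_lin(x), f_lin_T(y_bar) @ x)
```

下文逐条给出的转置规则(mul、neg、add、reduce_sum 等)都满足这一恒等式;组合起来,整个线性 jaxpr 的反向解释也就满足它。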
由于我们有了 jaxpr 形式的线性计算,而不仅仅是一个 Python 可调用对象,我们可以将转置转换实现为一个 jaxpr 解释器。
def vjp_flat(f, *primals_in):
pvals_in = ([PartialVal.known(x) for x in primals_in] +
[PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
def f_jvp(*primals_tangents_in):
primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
return [*primals_out, *tangents_out]
jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) # linearize
primal_pvals, _ = split_half(pvals_out)
assert all(pval.is_known for pval in primal_pvals)
primals_out = [pval.const for pval in primal_pvals]
transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
return primals_out, f_vjp
def vjp(f, *primals_in):
primals_in_flat, in_tree = tree_flatten(primals_in)
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
def f_vjp(*cotangents_out):
cotangents_out_flat, _ = tree_flatten(cotangents_out)
cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
return tree_unflatten(in_tree, cotangents_in_flat)
return primals_out, f_vjp
class UndefPrimal(NamedTuple):
aval: ShapedArray
register_pytree_node(UndefPrimal,
lambda u: (u.aval, ()),
lambda aval, _: UndefPrimal(aval))
我们用 UndefPrimal 实例来指示想要相对于哪些参数进行转置。之所以需要它们,是因为一般来说,在把闭包变量显式化之后,我们想要把类型为 a -> b -o c 的函数转置为类型为 a -> c -o b 的函数。更一般地,函数在其上是线性的那些输入可能散布在参数列表的各处,因此我们用 UndefPrimal 标出线性位置。我们把 UndefPrimal 注册为 pytree 节点,因为 pytree 机制提供了一种方便的方式把这些占位符从参数列表中剪除。
接下来,我们可以编写 eval_jaxpr_transposed,以及所有至少在一个参数上可以是线性的基本操作的转置规则。
# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
) -> list[Any]:
primal_env: dict[Var, Any] = {}
ct_env: dict[Var, Any] = {}
def read_primal(x: Atom) -> Any:
return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val
def write_primal(v: Var, val: Any) -> None:
if type(val) is not UndefPrimal:
primal_env[v] = val
def read_cotangent(v: Var) -> Any:
return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))
def write_cotangent(x: Atom, val: Any):
if type(x) is Var and val is not None:
ct_env[x] = add(ct_env[x], val) if x in ct_env else val
map(write_primal, jaxpr.in_binders, args)
map(write_cotangent, jaxpr.outs, cotangents)
for eqn in jaxpr.eqns[::-1]:
primals_in = map(read_primal, eqn.inputs)
cts_in = map(read_cotangent, eqn.out_binders)
rule = transpose_rules[eqn.primitive]
cts_out = rule(cts_in, *primals_in, **eqn.params)
map(write_cotangent, eqn.inputs, cts_out)
return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
if type(x) is UndefPrimal]
transpose_rules = {}
def mul_transpose_rule(cts, x, y):
z_bar, = cts
assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule
def neg_transpose_rule(cts, x):
ybar, = cts
assert type(x) is UndefPrimal
return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule
def add_transpose_rule(cts, x, y):
z_bar, = cts
return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule
def reduce_sum_transpose_rule(cts, x, *, axis):
y_bar, = cts
return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
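reduce_sum 的转置是 broadcast 这一点,同样可以用内积恒等式数值检验:⟨y̅, sum(x, axis)⟩ 应等于 ⟨broadcast(y̅, x.shape, axis), x⟩。

```python
import numpy as np

x = np.arange(12.).reshape(3, 4)
y_bar = np.array([1., 2., 3., 4.])        # 与 sum(x, axis=0) 同形的余切

fwd = x.sum(axis=0)                       # reduce_sum(axis=0)
bwd = np.broadcast_to(y_bar, x.shape)     # 其转置:沿被求和的轴广播回去
# <y_bar, sum(x)> == <broadcast(y_bar), x>
assert np.isclose((y_bar * fwd).sum(), (bwd * x).sum())
```

直观地说,求和把同一轴上的多个元素线性地合并为一个,其转置就是把余切沿该轴复制回每个元素。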
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
del num_consts # Unused
undef_primals = [type(x) is UndefPrimal for x in invals]
transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
residuals, _ = partition_list(undef_primals, invals)
outs = bind(xla_call_p, *new_consts, *residuals, *cts,
jaxpr=transposed_jaxpr, num_consts=len(new_consts))
outs = iter(outs)
return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
) -> tuple[Jaxpr, list[Any]]:
avals_in, avals_out = typecheck_jaxpr(jaxpr)
traceable = partial(eval_jaxpr_transposed, jaxpr)
args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
typecheck_jaxpr(trans_jaxpr)
return trans_jaxpr, consts
既然我们可以线性化和转置,我们终于可以编写 grad 了。
def grad(f):
def gradfun(x, *xs):
y, f_vjp = vjp(f, x, *xs)
if np.shape(y) != (): raise TypeError
x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
return x_bar
return gradfun
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(np.float64(-0.9899924966004454),) -0.9899924966004454
def f(x):
y = sin(x) * 2.
z = - y + x
return z
print(grad(f)(3.))
2.979984993200891
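上面打印的值可以手算核对:f(x) = −2 sin x + x,于是 f′(x) = 1 − 2 cos x;下面用 NumPy 验证该解析导数与打印值一致,并与中心差分交叉检验。

```python
import numpy as np

# f(x) = -(sin(x) * 2.) + x,解析导数为 f'(x) = 1 - 2*cos(x)
analytic = 1. - 2. * np.cos(3.)
assert np.isclose(analytic, 2.979984993200891)  # 与 grad(f)(3.) 的打印值一致

# 用中心差分做数值交叉检验
f = lambda x: -np.sin(x) * 2. + x
eps = 1e-6
numeric = (f(3. + eps) - f(3. - eps)) / (2 * eps)
assert np.isclose(numeric, analytic)
```
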
@jit
def f(x):
y = x * 2.
z = g(y)
return z
@jit
def g(x):
return cos(x) * 2.
print(grad(f)(3.))
1.1176619927957034
这是一个组合性压力测试。
# from core_test.py fun_with_nested_calls_2
def f(x):
@jit
def bar(y):
def baz(w):
q = jit(lambda x: y)(x)
q = q + jit(lambda: y)()
q = q + jit(lambda y: w + y)(y)
q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
return q
p, t = jvp(baz, (x + 1.0,), (y,))
return t + (x * p)
return bar(x)
def assert_allclose(*vals):
for v1, v2 in zip(vals[:-1], vals[1:]):
np.testing.assert_allclose(v1, v2)
ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)
deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)
hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
第五部分:控制流原语 cond#
接下来,我们将添加用于分阶段化控制流的高阶原语。它们类似于第三部分中的 jit(另一个高阶原语),但不同之处在于它们由多个可调用对象参数化,而不是只有一个。
添加 cond#
我们引入一个 cond 原语,用来在 jaxpr 内部表示对两个函数之一的条件应用。我们把 cond 的类型写作 Bool -> (a -> b) -> (a -> b) -> a -> b。用语言来说,cond 接受一个表示谓词的布尔值和两个类型相同的函数。根据谓词的值,它把最后一个参数应用于其中一个函数。
在 Python 中,我们将其表示为一个本身接受两个函数的函数。与 jit 一样,第一步是对其可调用参数调用 make_jaxpr,将它们转换为 jaxprs。
def cond(pred, true_fn, false_fn, *operands):
avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
if out_tree != out_tree_: raise TypeError
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
raise TypeError
outs = bind_cond(pred, *true_consts, *false_consts, *operands,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')
def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
  new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
  return new_jaxpr1, new_jaxpr2
def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
  return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
我们要求 true_jaxpr 和 false_jaxpr 具有相同的类型,但由于它们可能闭合不同的常量(并且因为 jaxprs 只能表示闭合项,即不能有自由变量,而是经过闭包转换的),我们需要用辅助函数 _join_jaxpr_consts 使这两个 jaxprs 的输入绑定器列表保持一致。(为了更经济,我们可以尝试识别形状相同的常量对并复用它们,但这里我们只是简单地连接两个常量列表。)
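用纯 Python 对连接后的调用约定做一个示意(其中的数值与函数都是为说明而假设的):连接常量之后,两个分支接受同一份参数列表 [*true_consts, *false_consts, *operands],并各自忽略对方的常量。

```python
# 示意 _join_jaxpr_consts 之后两个分支的统一调用约定(数值为假设)
n1, n2 = 1, 2                           # true 分支 1 个常量,false 分支 2 个
joined_args = (10.0, 20.0, 30.0, 5.0)   # [*true_consts, *false_consts, x]

def true_branch(*args):
    a, x = args[0], args[n1 + n2]       # 只使用自己的常量,忽略 false 分支的
    return a + x

def false_branch(*args):
    b, c, x = args[n1], args[n1 + 1], args[n1 + n2]  # 忽略 true 分支的常量
    return b * c + x

print(true_branch(*joined_args), false_branch(*joined_args))  # 15.0 605.0
```

关键点在于两个分支现在具有相同的输入类型,因此可以由同一个 cond 等式绑定。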
接下来我们可以开始为 cond 添加解释器规则。其求值规则很简单。
def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  if pred:
    return eval_jaxpr(true_jaxpr, operands)
  else:
    return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl
out = cond(True, lambda: 3, lambda: 4)
print(out)
3
对于其 JVP 和 vmap 规则,我们只需要调用我们为 jit 创建的相同的 jvp_jaxpr 和 vmap_jaxpr 工具,然后进行一次 _join_jaxpr_consts 传递。
def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  pred, *primals = primals
  _   , *tangents = tangents
  true_jaxpr , true_consts  = jvp_jaxpr(true_jaxpr)
  false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  pred    , *vals_in = vals_in
  pred_dim, *dims_in = dims_in
  if pred_dim is not not_mapped: raise NotImplementedError  # TODO
  true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
  false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule
xs = np.array([1., 2., 3])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
[2. 3. 4.]
请注意,我们目前不支持谓词值本身被批处理的情况。在主线 JAX 中,我们通过把条件改写为 select 原语来处理这种情况。只要 true_fun 和 false_fun 不涉及任何有副作用的原语,这种改写在语义上就是正确的。
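用 NumPy 对这种 select 语义做一个示意(假设性的草图,并非主线 JAX 的实际实现):两个分支都被求值,然后按谓词逐元素选择结果。

```python
import numpy as np

# 谓词本身被批处理时的 select 语义示意:
# 两个分支都被求值,然后按谓词逐元素选择。
# 仅当两个分支都没有副作用时,这种改写才在语义上正确。
pred = np.array([True, False, True])
xs = np.array([1., 2., 3.])
true_out = xs + 1.              # 批处理后的 true 分支输出
false_out = np.zeros_like(xs)   # 批处理后的 false 分支输出
out = np.where(pred, true_out, false_out)
print(out)  # [2. 0. 4.]
```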
这里没有体现、但在主线 JAX 中存在的另一点是:对两个类型相同的 jaxprs 应用转换,可能会得到类型不同的 jaxprs。例如,把主线 JAX 版本的 vmap_jaxpr 应用于恒等函数的 jaxpr
{ lambda a:float32[] .
  let
  in ( a ) }
在批次大小为 10 时,会得到一个输出被批处理的 jaxpr,类型为 [float32[10]] -> [float32[10]];而把它应用于零函数的 jaxpr
{ lambda a:float32[] .
  let
  in ( 0. ) }
则会得到一个输出未被批处理的 jaxpr,类型为 [float32[10]] -> [float32[]]。这是一种优化,目的是避免不必要地对值进行批处理。但它意味着在 cond 中,我们需要一个额外的步骤来连接(join)两个转换后的 jaxprs,使输出类型保持一致。我们在这里不需要这一步,因为我们选择让 vmap_jaxpr 始终在前导轴上对所有输出进行批处理。
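主线 JAX 中那个额外的一致化步骤,大致相当于把未批处理的输出广播到批处理后的形状(仅为示意,并非实际实现):

```python
import numpy as np

# 把零函数分支的未批处理输出 float32[] 广播为 float32[10],
# 使其与恒等函数分支的批处理输出类型一致(假设批次大小为 10)
axis_size = 10
unbatched_out = np.float32(0.)
batched_out = np.broadcast_to(unbatched_out, (axis_size,))
print(batched_out.shape)  # (10,)
```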
接下来我们可以转向抽象求值和 XLA 降低规则。
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, out_avals, in_vals, *, true_jaxpr, false_jaxpr):
  del in_avals  # Unused
  pred, *in_vals = in_vals
  op = hlo.IfOp([aval_to_ir_type(aval) for aval in out_avals], pred)
  with ir.InsertionPoint(op.true_branch.blocks.append()):
    hlo.return_(jaxpr_subcomp(c, true_jaxpr, in_vals))
  with ir.InsertionPoint(op.false_branch.blocks.append()):
    hlo.return_(jaxpr_subcomp(c, false_jaxpr, in_vals))
  return op.results
hlo_translations[cond_p] = cond_translation
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
2
最后,为了支持反向模式自动微分,我们需要部分求值规则和转置规则。对于部分求值,我们需要引入另一个处理 jaxpr 的工具 _join_jaxpr_res,以应对如下事实:对 true_fun 和 false_fun 做部分求值通常会产生不同的残差。我们用 _join_jaxpr_res 使转换后的 jaxprs 的输出类型一致(而 _join_jaxpr_consts 处理的是输入类型)。
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  pred_tracer, *tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in tracers]
  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  known_tracers, unknown_tracers = partition_list(in_uks, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind_cond(pred, *known_vals,
                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval
def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
                       ) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  out_uks = map(op.or_, t_out_uks, f_out_uks)
  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)
  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
  num_res = t_nres + f_nres
  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res
def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
  out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
  assert out_types1 == out_types2
  outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
  outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
  zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
  zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
  new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
  new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
  return new_jaxpr1, new_jaxpr2
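用 NumPy 对 _join_jaxpr_res 的输出填充方式做一个示意(数值为假设):若 true 分支的残差是 [r],false 分支的残差是 [s, t],则连接后两个分支的残差输出分别为 [r, 0, 0] 和 [0, s, t],两者类型一致。

```python
import numpy as np

# _join_jaxpr_res 的残差零填充示意(数值为假设)
r = np.float32(2.)                              # true 分支的残差
s, t = np.float32(3.), np.ones(2, np.float32)   # false 分支的残差

joined_res_true  = [r, np.zeros_like(s), np.zeros_like(t)]  # res1 + zeros_like2
joined_res_false = [np.zeros_like(r), s, t]                 # zeros_like1 + res2

# 两个分支的残差输出列表现在具有相同的类型(形状与 dtype)
assert [(x.shape, x.dtype) for x in joined_res_true] == \
       [(x.shape, x.dtype) for x in joined_res_false]
```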
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                   ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
  pred_unk, *unks_in = unks_in
  assert not pred_unk
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
  outs1, outs2 = partition_list(unks_out, eqn.out_binders)
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                  outs1 + residuals)
  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                  outs2)
  res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
  return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
3.14
转置是 transpose_jaxpr 的一个相当直接的应用。
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  undef_primals = tuple(type(x) is UndefPrimal for x in invals)
  true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
  false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  res = [x for x in invals if type(x) is not UndefPrimal]
  outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  outs = iter(outs)
  return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
2.0