在 JAX 中编写自定义 Jaxpr 解释器#
JAX 提供了几种可组合的函数变换(jit、grad、vmap 等),使编写简洁、加速的代码成为可能。
在这里,我们通过编写一个自定义 Jaxpr 解释器,展示如何向系统中添加自己的函数变换。这样,我们就可以免费获得与其他所有变换的可组合性。
此示例使用内部 JAX API,这些 API 随时可能发生更改。除非在API 文档中,否则应假定所有内容均为内部 API。
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random
JAX 在做什么?#
JAX 为数值计算提供了一个类似 NumPy 的 API,可以直接使用,但 JAX 的真正威力来自于可组合的函数变换。以 jit 函数变换为例,它接收一个函数并返回一个语义上相同的函数,但该函数会被 XLA 惰性编译以用于加速器。
x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)
当我们调用 fast_f 时,会发生什么?JAX 会追踪函数并构建一个 XLA 计算图。然后,该图会被 JIT 编译并执行。其他变换的工作方式类似,它们首先追踪函数,然后以某种方式处理输出追踪。要了解更多关于 JAX 追踪机制的信息,您可以参考 README 中的“工作原理”部分。
Jaxpr 追踪器#
Jax 中一个特别重要的追踪器是 Jaxpr 追踪器,它将操作记录到 Jaxpr(Jax 表达式)中。Jaxpr 是一种可以像微型函数式编程语言一样进行评估的数据结构,因此 Jaxpr 是函数变换的有用中间表示。
为了初步了解 Jaxpr,可以考虑 make_jaxpr 变换。 make_jaxpr 本质上是一个“美化打印”变换:它将一个函数转换为一个函数,该函数接收示例参数后,会生成其计算的 Jaxpr 表示。 make_jaxpr 对于调试和内省非常有用。让我们用它来看看一些示例 Jaxpr 的结构。
def examine_jaxpr(closed_jaxpr):
jaxpr = closed_jaxpr.jaxpr
print("invars:", jaxpr.invars)
print("outvars:", jaxpr.outvars)
print("constvars:", jaxpr.constvars)
for eqn in jaxpr.eqns:
print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
print()
print("jaxpr:", jaxpr)
def foo(x):
return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))
print()
def bar(w, b, x):
return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=140333297167360):int32[]]
outvars: [Var(id=140333297243136):int32[]]
constvars: []
equation: [Var(id=140333297167360):int32[], Literal(TypedInt(1, dtype=int32))] add [Var(id=140333297243136):int32[]] {}
jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1:i32[] in (b,) }
bar
=====
invars: [Var(id=140333297464768):float32[5,10], Var(id=140333297464896):float32[5], Var(id=140333297464960):float32[10]]
outvars: [Var(id=140333297554432):float32[5], Var(id=140333297464960):float32[10]]
constvars: []
equation: [Var(id=140333297464768):float32[5,10], Var(id=140333297464960):float32[10]] dot_general [Var(id=140333297550400):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32'), 'out_sharding': None}
equation: [Var(id=140333297550400):float32[5], Var(id=140333297464896):float32[5]] add [Var(id=140333297553728):float32[5]] {}
equation: [Literal(TypedNdArray(1., dtype=float32))] broadcast_in_dim [Var(id=140333297554304):float32[5]] {'shape': (5,), 'broadcast_dimensions': (), 'sharding': None}
equation: [Var(id=140333297553728):float32[5], Var(id=140333297554304):float32[5]] add [Var(id=140333297554432):float32[5]] {}
jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
d:f32[5] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] a c
e:f32[5] = add d b
f:f32[5] = broadcast_in_dim[
broadcast_dimensions=()
shape=(5,)
sharding=None
] 1.0:f32[]
g:f32[5] = add e f
in (g, c) }
jaxpr.invars— Jaxpr 的invars是 Jaxpr 输入变量的列表,类似于 Python 函数中的参数。jaxpr.outvars— Jaxpr 的outvars是 Jaxpr 返回的变量。每个 Jaxpr 都有多个输出。jaxpr.constvars—constvars是也作为 Jaxpr 输入的变量列表,但对应于追踪中的常量(我们稍后将详细介绍)。jaxpr.eqns— 一个方程列表,本质上是 let 绑定。每个方程包含输入变量列表、输出变量列表和一个原语,用于评估输入以产生输出。每个方程还有一个params,一个参数字典。
总而言之,Jaxpr 封装了一个简单的程序,该程序可以使用输入进行评估以产生输出。我们稍后将详细介绍如何做到这一点。现在要记住的重要一点是,Jaxpr 是一个我们可以随意操作和评估的数据结构。
Jaxpr 有什么用?#
Jaxpr 是易于转换的简单程序表示。由于 Jax 允许我们将 Jaxpr 从 Python 函数中分阶段提取出来,因此它为我们提供了一种转换用 Python 编写的数值程序的方法。
你的第一个解释器:invert#
让我们尝试实现一个简单的函数“求逆器”,它接收原始函数的输出并返回产生这些输出的输入。目前,我们只关注由其他可逆的一元函数组成的一些简单的一元函数。
目标
def f(x):
return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
我们将通过以下方式实现这一点:(1)将 f 追踪到 Jaxpr,然后(2)反向解释 Jaxpr。在反向解释 Jaxpr 时,对于每个方程,我们将在表中查找原语的逆并应用它。
1. 追踪函数#
让我们使用 make_jaxpr 将函数追踪到 Jaxpr。
# Importing Jax functions useful for tracing/interpreting.
from functools import wraps
from jax import lax
from jax.extend import core
from jax._src.util import safe_map
jax.make_jaxpr 返回一个闭合的 Jaxpr,这是一个与追踪中的常量(literals)绑定的 Jaxpr。
def f(x):
return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]
2. 评估 Jaxpr#
在编写自定义 Jaxpr 解释器之前,让我们先实现“默认”解释器 eval_jaxpr,它按原样评估 Jaxpr,计算出与原始、未转换的 Python 函数相同的值。
为此,我们首先创建一个环境来存储每个变量的值,并在评估 Jaxpr 中的每个方程时更新该环境。
def eval_jaxpr(jaxpr, consts, *args):
# Mapping from variable -> value
env = {}
def read(var):
# Literals are values baked into the Jaxpr
if type(var) is core.Literal:
return var.val
return env[var]
def write(var, val):
env[var] = val
# Bind args and consts to environment
safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)
# Loop through equations and evaluate primitives using `bind`
for eqn in jaxpr.eqns:
# Read inputs to equation from environment
invals = safe_map(read, eqn.invars)
# `bind` is how a primitive is called
outvals = eqn.primitive.bind(*invals, **eqn.params)
# Primitives may return multiple outputs or not
if not eqn.primitive.multiple_results:
outvals = [outvals]
# Write the results of the primitive into the environment
safe_map(write, eqn.outvars, outvals)
# Read the final result of the Jaxpr from the environment
return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
请注意,即使原始函数不是列表,eval_jaxpr 也会始终返回一个扁平列表。
此外,此解释器不处理高阶原语(如 jit 和 pmap),本指南不涵盖这些内容。您可以参考 core.eval_jaxpr(链接)来查看此解释器未涵盖的边界情况。
自定义 inverse Jaxpr 解释器#
一个 inverse 解释器看起来与 eval_jaxpr 没有太大区别。我们将首先设置一个注册表,它会将原语映射到它们的逆。然后,我们将编写一个自定义解释器,它在注册表中查找原语。
事实证明,这个解释器也与反向模式自动微分中使用的“转置”解释器在此处找到的类似。
inverse_registry = {}
现在,我们将为一些原语注册逆。按照惯例,Jax 中的原语以 _p 结尾,许多流行的原语都位于 lax 模块中。
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh
inverse 将首先追踪函数,然后自定义解释 Jaxpr。让我们设置一个简单的骨架。
def inverse(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
# Since we assume unary functions, we won't worry about flattening and
# unflattening arguments.
closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out[0]
return wrapped
现在我们只需要定义 inverse_jaxpr,它将向后遍历 Jaxpr,并在可能时反转原语。
def inverse_jaxpr(jaxpr, consts, *args):
env = {}
def read(var):
if type(var) is core.Literal:
return var.val
return env[var]
def write(var, val):
env[var] = val
# Args now correspond to Jaxpr outvars
safe_map(write, jaxpr.outvars, args)
safe_map(write, jaxpr.constvars, consts)
# Looping backward
for eqn in jaxpr.eqns[::-1]:
# outvars are now invars
invals = safe_map(read, eqn.outvars)
if eqn.primitive not in inverse_registry:
raise NotImplementedError(
f"{eqn.primitive} does not have registered inverse.")
# Assuming a unary function
outval = inverse_registry[eqn.primitive](*invals)
safe_map(write, eqn.invars, [outval])
return safe_map(read, jaxpr.invars)
就这样!
def f(x):
return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
重要的是,你可以追踪 Jaxpr 解释器。
jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
这就是向系统中添加新变换所需的所有内容,并且可以免费与其他变换进行组合!例如,我们可以将 jit、vmap 和 grad 与 inverse 一起使用!
jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 , 2.2551253, 1.3155028, 1. ], dtype=float32, weak_type=True)
留给读者的练习#
处理具有多个参数且输入部分已知的原语,例如
lax.add_p、lax.mul_p。处理
xla_call和xla_pmap原语,这些原语将无法与当前编写的eval_jaxpr和inverse_jaxpr一起使用。