JAX 内部机制:jaxpr 语言#
Jaxprs 是 JAX 程序的内部中间表示 (IR)。它们是显式类型化的、函数式的、一阶的,并采用代数范式 (ANF)。
从概念上讲,JAX 变换(例如 jax.jit()
或 jax.grad()
)可以理解为首先将要变换的 Python 函数跟踪特化为一种小巧且行为良好的中间形式,然后通过特定于变换的解释规则进行解释。
JAX 能够将如此强大的功能封装在一个小巧的软件包中的原因之一是,它始于一个熟悉且灵活的编程接口(带有 NumPy 的 Python),并利用实际的 Python 解释器完成了大部分繁重的工作,将计算的精髓提炼成一种简单的静态类型表达式语言,其中包含有限的高阶特性。
这种语言就是 jaxpr 语言。jaxpr 术语语法如下所示
jaxpr ::=
{ lambda <binder> , ... .
let <eqn>
...
in ( <atom> , ... ) }
binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>
eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...
并非所有 Python 程序都可以用这种方式处理,但事实证明,许多科学计算和机器学习程序都可以。
在继续之前,请记住并非所有 JAX 变换都像上面描述的那样字面上具体化一个 jaxpr。其中一些,例如求导或批处理,将在跟踪过程中逐步应用变换。尽管如此,如果想了解 JAX 的内部工作原理,或者利用 JAX 跟踪的结果,理解 jaxpr 是很有用的。
jax.core.ClosedJaxpr
#
jaxpr 实例表示一个具有一个或多个类型化参数(输入变量)和一个或多个类型化结果的函数。结果仅依赖于输入变量;没有从封闭作用域捕获的自由变量。输入和输出具有类型,在 JAX 中表示为抽象值。
代码中 jaxpr 有两种相关的表示形式,jax.core.Jaxpr
和 jax.core.ClosedJaxpr
。jax.core.ClosedJaxpr
表示一个部分应用的 jax.core.Jaxpr
,当你使用 jax.make_jaxpr()
检查 jaxpr 时就会得到它。它具有以下字段
jaxpr
:是一个jax.core.Jaxpr
,表示函数的实际计算内容(如下所述)。consts
是一个常量列表。
ClosedJaxpr
最有趣的部分是实际的执行内容,表示为一个 jax.core.Jaxpr
,其打印语法如下
jaxpr ::= { lambda Var* ; Var+.
let Eqn*
in [Expr+] }
其中
jaxpr 的参数显示为两个变量列表,由
;
分隔第一组变量是为了代表已提升的常量而引入的。它们被称为
constvars
,在jax.core.ClosedJaxpr
中,consts
字段包含相应的值。第二组变量,称为
invars
,对应于被跟踪的 Python 函数的输入。
Eqn*
是一个方程列表,定义了引用中间表达式的中间变量。每个方程将一个或多个变量定义为对某些原子表达式应用原语的结果。每个方程仅使用输入变量和由先前方程定义的中间变量。Expr+
:是 jaxpr 的输出原子表达式(字面量或变量)列表。
方程打印如下
Eqn ::= let Var+ = Primitive [ Param* ] Expr+
其中
Var+
是一个或多个中间变量,将定义为原语调用的输出(某些原语可以返回多个值)。Expr+
是一个或多个原子表达式,每个都是变量或字面常量。一个特殊变量unitvar
或字面量unit
,打印为*
,表示在计算的其余部分不需要且已被省略的值。也就是说,单元只是占位符。Param*
是原语的零个或多个命名参数,打印在方括号中。每个参数显示为Name = Value
。
大多数 jaxpr 原语都是一阶的(它们只接受一个或多个 Expr 作为参数)
Primitive := add | sub | sin | mul | ...
最常见的 jaxpr 原语在 jax.lax
模块中有所记录。
例如,下面是为函数 func1
生成的 jaxpr
from jax import make_jaxpr
import jax.numpy as jnp
def func1(first, second):
temp = first + jnp.sin(second) * 3.
return jnp.sum(temp)
print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0:f32[]
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
这里没有 constvars,a
和 b
是输入变量,它们分别对应 first
和 second
函数参数。标量字面量 3.0
被直接内联。 reduce_sum
原语除了操作数 e
之外,还有命名参数 axes
和 input_shape
。
请注意,尽管调用 JAX 的程序在执行时会构建 jaxpr,但 Python 级别的控制流和 Python 级别的函数会正常执行。这意味着即使一个 Python 程序包含函数和控制流,结果 jaxpr 也不必包含控制流或高阶特性。
例如,在跟踪函数 func3
时,JAX 将内联对 inner
的调用和条件语句 if second.shape[0] > 4
,并将生成与之前相同的 jaxpr
def func2(inner, first, second):
temp = first + inner(second) * 3.
return jnp.sum(temp)
def inner(second):
if second.shape[0] > 4:
return jnp.sin(second)
else:
assert False
def func3(first, second):
return func2(inner, first, second)
print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0:f32[]
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
处理 pytrees#
在 jaxpr 中没有元组类型;相反,原语接受多个输入并产生多个输出。当处理具有结构化输入或输出的函数时,JAX 会将它们展平,在 jaxpr 中它们将显示为输入和输出列表。有关更多详细信息,请参阅 Pytrees 教程。
例如,以下代码生成的 jaxpr 与您之前看到的相同(带有两个输入变量,每个输入元组的元素一个)
def func4(arg): # The `arg` is a pair.
temp = arg[0] + jnp.sin(arg[1]) * 3.
return jnp.sum(temp)
print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
c:f32[8] = sin b
d:f32[8] = mul c 3.0:f32[]
e:f32[8] = add a d
f:f32[] = reduce_sum[axes=(0,)] e
in (f,) }
常量变量 (vars)#
jaxpr 中的某些值是常量,因为它们的值不依赖于 jaxpr 的参数。当这些值是标量时,它们直接在 jaxpr 方程中表示。非标量数组常量则被提升到顶层 jaxpr,在那里它们对应于常量变量(“constvars”)。这些 constvars 与其他 jaxpr 参数(“invars”)的区别仅在于簿记约定。
高阶 JAX 原语#
Jaxpr 包含几个高阶 JAX 原语。它们更复杂,因为它们包含子 jaxpr。
cond
原语 (条件语句)#
JAX 会跟踪正常的 Python 条件语句。为了捕获条件表达式以进行动态执行,必须使用 jax.lax.switch()
和 jax.lax.cond()
构造函数,它们具有以下签名
lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B
lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B
这两者都会在内部绑定一个名为 cond
的原语。jaxpr 中的 cond
原语反映了 lax.switch()
更通用的签名:它接受一个整数,表示要执行的分支的索引(钳位到有效索引范围)。
例如
from jax import lax
def one_of_three(index, arg):
return lax.switch(index, [lambda x: x + 1.,
lambda x: x - 2.,
lambda x: x + 3.],
arg)
print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
d:i32[] = clamp 0:i32[] c 2:i32[]
e:f32[] = cond[
branches=(
{ lambda ; f:f32[]. let g:f32[] = add f 1.0:f32[] in (g,) }
{ lambda ; h:f32[]. let i:f32[] = sub h 2.0:f32[] in (i,) }
{ lambda ; j:f32[]. let k:f32[] = add j 3.0:f32[] in (k,) }
)
] d b
in (e,) }
cond
原语有多个参数
branches
是与分支函数对应的 jaxpr。在此示例中,这些函数每个都接受一个输入变量,对应于x
。linear
是一个布尔值元组,由自动微分机制内部使用,用于编码条件语句中哪些输入参数是线性使用的。
上述 cond 原语实例接受两个操作数。第一个 (d
) 是分支索引,然后 b
是将传递给由分支索引选择的 branches
中任何 jaxpr 的操作数 (arg
)。
另一个示例,使用 jax.lax.cond()
from jax import lax
def func7(arg):
return lax.cond(arg >= 0.,
lambda xtrue: xtrue + 3.,
lambda xfalse: xfalse - 3.,
arg)
print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
b:bool[] = ge a 0.0:f32[]
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
d:f32[] = cond[
branches=(
{ lambda ; e:f32[]. let f:f32[] = sub e 3.0:f32[] in (f,) }
{ lambda ; g:f32[]. let h:f32[] = add g 3.0:f32[] in (h,) }
)
] c a
in (d,) }
在这种情况下,布尔谓词被转换为整数索引(0 或 1),而 branches
是分别对应于 false 和 true 分支函数的 jaxpr,顺序是先 false 后 true。同样,每个函数都接受一个输入变量,分别对应于 xfalse
和 xtrue
。
以下示例展示了一个更复杂的情况,其中分支函数的输入是一个元组,并且 false
分支函数包含一个常量 jnp.ones(1)
,该常量被提升为一个 constvar
。
def func8(arg1, arg2): # Where `arg2` is a pair.
return lax.cond(arg1 >= 0.,
lambda xtrue: xtrue[0],
lambda xfalse: jnp.array([1]) + xfalse[1],
arg2)
print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
e:bool[] = ge b 0.0:f32[]
f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
g:f32[1] = cond[
branches=(
{ lambda ; h:i32[1] i:f32[1] j:f32[]. let
k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
l:f32[1] = add k j
in (l,) }
{ lambda ; m:i32[1] n:f32[1] o:f32[]. let in (n,) }
)
] f a c d
in (g,) }
while
原语#
就像条件语句一样,Python 循环在跟踪期间是内联的。如果要捕获循环以进行动态执行,则必须使用几种特殊操作之一:jax.lax.while_loop()
(一个原语)和 jax.lax.fori_loop()
(一个生成 while_loop 原语的辅助函数)
lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
在上述签名中,C
代表循环“carry”值的类型。例如,下面是一个 fori_loop
示例
import numpy as np
def func10(arg, n):
ones = jnp.ones(arg.shape) # A constant.
return lax.fori_loop(0, n,
lambda i, carry: carry + ones * 3. + arg,
arg + ones)
print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
c:f32[16] = broadcast_in_dim[
broadcast_dimensions=()
shape=(16,)
sharding=None
] 1.0:f32[]
d:f32[16] = add a c
_:i32[] _:i32[] e:f32[16] = while[
body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
k:i32[] = add h 1:i32[]
l:f32[16] = mul f 3.0:f32[]
m:f32[16] = add j l
n:f32[16] = add m g
in (k, i, n) }
body_nconsts=2
cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
r:bool[] = lt o p
in (r,) }
cond_nconsts=0
] c a 0:i32[] b d
in (e,) }
while
原语接受 5 个参数:c a 0 b d
,如下所示
cond_jaxpr
的 0 个常量(因为cond_nconsts
为 0)body_jaxpr
的 2 个常量(c
和a
)carry 初始值的 3 个参数
scan
原语#
JAX 支持一种特殊的数组元素循环形式(具有静态已知形状)。迭代次数固定这一事实使得这种循环形式易于反向微分。此类循环使用 jax.lax.scan()
函数构建
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
这以 Haskell 类型签名 的形式编写:C
是 scan
进位 (carry) 的类型,A
是输入数组的元素类型,B
是输出数组的元素类型。
以下面的函数 func11
为例
def func11(arr, extra):
ones = jnp.ones(arr.shape) # A constant
def body(carry, aelems):
# carry: running dot-product of the two arrays
# aelems: a pair with corresponding elements from the two arrays
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
c:f32[16] = broadcast_in_dim[
broadcast_dimensions=()
shape=(16,)
sharding=None
] 1.0:f32[]
d:f32[] e:f32[16] = scan[
_split_transpose=False
jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
j:f32[] = mul h i
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
l:f32[] = add k j
m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
n:f32[] = add l m
in (n, g) }
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1
reverse=False
unroll=1
] b 0.0:f32[] a c
in (d, e) }
linear
参数描述了每个输入变量是否保证在线性化体中线性使用。一旦 scan
经过线性化,更多参数将变为线性。
scan
原语接受 4 个参数:b 0.0 a c
,其中
一个用于 body 的自由变量
一个用于 carry 的初始值
接下来 2 个是 scan 操作的数组
(p)jit
原语#
调用原语来自 JIT 编译,它封装了一个子 jaxpr 以及指定后端和计算应在哪个设备上运行的参数。例如
from jax import jit
def func12(arg):
@jit
def inner(x):
return x + arg * jnp.ones(1) # Include a constant in the inner function.
return arg + inner(arg - 2.)
print(make_jaxpr(func12)(1.))
{ lambda ; a:f32[]. let
b:f32[] = sub a 2.0:f32[]
c:f32[1] = jit[
name=inner
jaxpr={ lambda ; a:f32[] b:f32[]. let
d:f32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 1.0:f32[]
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
f:f32[1] = mul e d
g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
c:f32[1] = add g f
in (c,) }
] a b
h:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
i:f32[1] = add h c
in (i,) }