JAX 内部:jaxpr 语言#

Jaxpr 是 JAX 程序的内部中间表示 (IR)。它们是显式类型化、函数式、一阶的,并且采用代数范式 (ANF)。

从概念上讲,可以将 JAX 转换(例如 jax.jit()jax.grad())视为首先将要转换的 Python 函数 trace-specializing 成一个小的、行为良好的中间形式,然后使用特定于转换的解释规则对其进行解释。

JAX 能够将如此强大的功能集成到一个小型的软件包中的原因之一是,它从一个熟悉且灵活的编程接口(带有 NumPy 的 Python)开始,并使用实际的 Python 解释器来完成大部分繁重的工作,将计算的本质提炼成一种简单的静态类型表达式语言,该语言具有有限的高阶功能。

该语言就是 jaxpr 语言。jaxpr 术语语法如下所示

jaxpr ::=
  { lambda <binder> , ... .
    let <eqn>
        ...
    in ( <atom> , ... ) }

binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>

eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...

并非所有 Python 程序都可以通过这种方式处理,但事实证明,许多科学计算和机器学习程序都可以。

在继续之前,请记住,并非所有 JAX 转换都会按上述方式实际物化 jaxpr。其中一些转换(例如微分或批处理)将在跟踪期间以增量方式应用转换。然而,如果想了解 JAX 的内部工作原理,或者利用 JAX 跟踪的结果,那么理解 jaxpr 是有用的。

jax.core.ClosedJaxpr#

jaxpr 实例表示一个具有一个或多个类型化参数(输入变量)和一个或多个类型化结果的函数。结果仅取决于输入变量;没有从封闭作用域捕获的自由变量。输入和输出具有类型,在 JAX 中,类型表示为抽象值。

在 jaxpr 的代码中有两种相关的表示形式,jax.core.Jaxprjax.core.ClosedJaxprjax.core.ClosedJaxpr 表示部分应用的 jax.core.Jaxpr,并且是使用 jax.make_jaxpr() 检查 jaxpr 时获得的结果。它具有以下字段

  • jaxpr:是一个 jax.core.Jaxpr,表示函数的实际计算内容(如下所述)。

  • consts 是常量列表。

ClosedJaxpr 中最有趣的部分是实际的执行内容,表示为 jax.core.Jaxpr,使用以下语法打印

jaxpr ::= { lambda Var* ; Var+.
            let Eqn*
            in  [Expr+] }

其中

  • jaxpr 的参数显示为由 ; 分隔的两个变量列表

    • 第一组变量是为代表已提升出的常量而引入的变量。这些变量称为 constvars,在 jax.core.ClosedJaxpr 中,consts 字段保存相应的值。

    • 第二组变量,称为 invars,对应于跟踪的 Python 函数的输入。

  • Eqn* 是方程列表,定义引用中间表达式的中间变量。每个方程定义一个或多个变量作为在某些原子表达式上应用 primitive 的结果。每个方程仅使用输入变量和先前方程定义的中间变量。

  • Expr+:是 jaxpr 的输出原子表达式(文字或变量)列表。

方程的打印方式如下

Eqn  ::= let Var+ = Primitive [ Param* ] Expr+

其中

  • Var+ 是要定义为一个 primitive 调用输出的一个或多个中间变量(某些 primitives 可以返回多个值)。

  • Expr+ 是一个或多个原子表达式,每个表达式要么是变量,要么是文字常量。特殊变量 unitvar 或文字 unit,打印为 *,表示计算的其余部分不需要的值,并且已被省略。也就是说,units 只是占位符。

  • Param* 是 primitive 的零个或多个命名参数,以方括号打印。每个参数都显示为 Name = Value

大多数 jaxpr primitives 都是一阶的(它们仅接受一个或多个 Expr 作为参数)

Primitive := add | sub | sin | mul | ...

最常见的 jaxpr primitives 记录在 jax.lax 模块中。

例如,这是为下面的函数 func1 生成的 jaxpr

from jax import make_jaxpr
import jax.numpy as jnp

def func1(first, second):
   temp = first + jnp.sin(second) * 3.
   return jnp.sum(temp)

print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

这里没有 constvars,ab 是输入变量,它们分别对应于 firstsecond 函数参数。标量文字 3.0 保留在内联。reduce_sum primitive 除了操作数 e 之外,还具有命名参数 axesinput_shape

请注意,即使调用 JAX 的程序的执行会构建 jaxpr,Python 级别的控制流和 Python 级别的函数也会正常执行。这意味着,仅仅因为 Python 程序包含函数和控制流,生成的 jaxpr 也不必包含控制流或高阶功能。

例如,当跟踪函数 func3 时,JAX 将内联对 inner 的调用和条件 if second.shape[0] > 4,并将生成与之前相同的 jaxpr

def func2(inner, first, second):
  temp = first + inner(second) * 3.
  return jnp.sum(temp)

def inner(second):
  if second.shape[0] > 4:
    return jnp.sin(second)
  else:
    assert False

def func3(first, second):
  return func2(inner, first, second)

print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

处理 pytrees#

在 jaxpr 中,没有元组类型;相反,primitives 接受多个输入并产生多个输出。当处理具有结构化输入或输出的函数时,JAX 将展平这些输入和输出,并在 jaxpr 中,它们将显示为输入和输出列表。有关更多详细信息,请参阅 Pytrees 教程。

例如,以下代码生成与您之前看到的 jaxpr 相同的 jaxpr(具有两个输入变量,每个变量对应于输入元组的一个元素)

def func4(arg):  # The `arg` is a pair.
  temp = arg[0] + jnp.sin(arg[1]) * 3.
  return jnp.sum(temp)

print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

常量变量 (vars)#

jaxpr 中的某些值是常量,因为它们的值不依赖于 jaxpr 的参数。当这些值是标量时,它们直接在 jaxpr 方程中表示。非标量数组常量而是被提升到顶层 jaxpr,在顶层 jaxpr 中,它们对应于常量变量(“constvars”)。这些 constvars 与其他 jaxpr 参数(“invars”)的不同之处仅在于簿记约定。

高阶 JAX primitives#

Jaxpr 包括几个高阶 JAX primitives。它们更复杂,因为它们包含子 jaxpr。

cond primitive (条件)#

JAX 跟踪正常的 Python 条件语句。要捕获用于动态执行的条件表达式,必须使用 jax.lax.switch()jax.lax.cond() 构造函数,它们的签名如下

lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B

lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B

两者都将在内部绑定一个名为 cond 的 primitive。cond primitive 在 jaxpr 中反映了 lax.switch() 的更通用签名:它接受一个整数,表示要执行的分支的索引(钳制到有效索引范围)。

例如

from jax import lax

def one_of_three(index, arg):
  return lax.switch(index, [lambda x: x + 1.,
                            lambda x: x - 2.,
                            lambda x: x + 3.],
                    arg)

print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    d:i32[] = clamp 0 c 2
    e:f32[] = cond[
      branches=(
        { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
        { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
        { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
      )
    ] d b
  in (e,) }

cond primitive 有许多参数

  • branches 是 jaxpr,对应于分支函数。在此示例中,这些函数每个都接受一个输入变量,对应于 x

  • linear 是一个布尔元组,由自动微分机制在内部使用,以编码哪些输入参数在线性条件中使用。

上面的 cond primitive 实例接受两个操作数。第一个 (d) 是分支索引,然后 b 是操作数 (arg),将传递给 branches 中由分支索引选择的 jaxpr。

另一个示例,使用 jax.lax.cond()

from jax import lax

def func7(arg):
  return lax.cond(arg >= 0.,
                  lambda xtrue: xtrue + 3.,
                  lambda xfalse: xfalse - 3.,
                  arg)

print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
    b:bool[] = ge a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
        { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
      )
    ] c a
  in (d,) }

在这种情况下,布尔谓词被转换为整数索引(0 或 1),并且 branches 是 jaxpr,它们对应于 false 和 true 分支函数,顺序如此。同样,每个函数都接受一个输入变量,分别对应于 xfalsextrue

以下示例显示了更复杂的情况,即分支函数的输入是元组,并且 false 分支函数包含一个常量 jnp.ones(1),该常量被提升为 constvar

def func8(arg1, arg2):  # Where `arg2` is a pair.
  return lax.cond(arg1 >= 0.,
                  lambda xtrue: xtrue[0],
                  lambda xfalse: jnp.array([1]) + xfalse[1],
                  arg2)

print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
    e:bool[] = ge b 0.0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    g:f32[1] = cond[
      branches=(
        { lambda ; h:i32[1] i:f32[1] j:f32[]. let
            k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
            l:f32[1] = add k j
          in (l,) }
        { lambda ; m_:i32[1] n:f32[1] o:f32[]. let  in (n,) }
      )
    ] f a c d
  in (g,) }

while primitive#

就像条件语句一样,Python 循环在跟踪期间内联。如果想捕获用于动态执行的循环,则必须使用几个特殊操作之一,jax.lax.while_loop() (一个 primitive) 和 jax.lax.fori_loop() (一个生成 while_loop primitive 的助手)

lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C

在上面的签名中,C 代表循环“carry”值的类型。例如,这是一个 fori_loop 示例

import numpy as np

def func10(arg, n):
  ones = jnp.ones(arg.shape)  # A constant.
  return lax.fori_loop(0, n,
                       lambda i, carry: carry + ones * 3. + arg,
                       arg + ones)

print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
    c:f32[16] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(16,)
      sharding=None
    ] 1.0
    d:f32[16] = add a c
    _:i32[] _:i32[] e:f32[16] = while[
      body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
          k:i32[] = add h 1
          l:f32[16] = mul f 3.0
          m:f32[16] = add j l
          n:f32[16] = add m g
        in (k, i, n) }
      body_nconsts=2
      cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
          r:bool[] = lt o p
        in (r,) }
      cond_nconsts=0
    ] c a 0 b d
  in (e,) }

while primitive 接受 5 个参数:c a 0 b d,如下所示

  • 用于 cond_jaxpr 的 0 个常量(因为 cond_nconsts 为 0)

  • 用于 body_jaxpr 的 2 个常量(ca

  • carry 初始值的 3 个参数

scan primitive#

JAX 支持对数组元素进行特殊形式的循环(具有静态已知形状)。迭代次数固定这一事实使得这种形式的循环易于反向微分。此类循环是使用 jax.lax.scan() 函数构建的

lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

这是用 Haskell 类型签名 编写的:Cscan carry 的类型,A 是输入数组的元素类型,B 是输出数组的元素类型。

对于示例,请考虑下面的函数 func11

def func11(arr, extra):
  ones = jnp.ones(arr.shape)  #  A constant
  def body(carry, aelems):
    # carry: running dot-product of the two arrays
    # aelems: a pair with corresponding elements from the two arrays
    ae1, ae2 = aelems
    return (carry + ae1 * ae2 + extra, carry)
  return lax.scan(body, 0., (arr, ones))

print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
    c:f32[16] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(16,)
      sharding=None
    ] 1.0
    d:f32[] e:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
          j:f32[] = mul h i
          k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
          l:f32[] = add k j
          m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          n:f32[] = add l m
        in (n, g) }
      length=16
      linear=(False, False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=1
    ] b 0.0 a c
  in (d, e) }

linear 参数描述了每个输入变量是否保证在线性体中使用。一旦 scan 完成线性化,更多参数将是线性的。

scan primitive 接受 4 个参数:b 0.0 a c,其中

  • 一个是 body 的自由变量

  • 一个是 carry 的初始值

  • 接下来的 2 个是 scan 操作的数组

(p)jit primitive#

call primitive 来自 JIT 编译,它封装了一个子 jaxpr 以及指定应在哪个后端和设备上运行计算的参数。例如

from jax import jit

def func12(arg):
  @jit
  def inner(x):
    return x + arg * jnp.ones(1)  # Include a constant in the inner function.
  return arg + inner(arg - 2.)

print(make_jaxpr(func12)(1.))
{ lambda ; a:f32[]. let
    b:f32[] = sub a 2.0
    c:f32[1] = pjit[
      name=inner
      jaxpr={ lambda ; a:f32[] b:f32[]. let
          d:f32[1] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1,)
            sharding=None
          ] 1.0
          e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
          f:f32[1] = mul e d
          g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
          c:f32[1] = add g f
        in (c,) }
    ] a b
    h:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    i:f32[1] = add h c
  in (i,) }