JAX PRNG 设计#

我们希望 PRNG 设计能够

  1. 富有表现力,即方便使用,并且不限制用户编写具有他们所需行为的数值程序的能力,

  2. 以与后端无关的方式实现可重现的程序执行,

  3. 具有@jit 编译边界和设备后端不变的语义,

  4. 通过 SIMD 硬件实现用于生成数组值的矢量化

  5. 可并行化,即它不会在没有显式数据依赖关系的随机函数调用之间增加排序约束,

  6. 可扩展到多副本、多核和分布式计算

  7. 符合 JAX 和 XLA 的语义及设计理念(这些最终受其他实际考虑因素的驱动)。

作为这些的推论,我们认为该设计应该是函数式的。另一个推论是,至少考虑到当前的硬件限制,我们将采用软件方式实现 PRNG。

TLDR JAX PRNG = Threefry 计数器 PRNG + 函数式数组导向的 分割模型

目录#

三种编程模型和玩具示例程序#

这是一个类似于 Numpy 程序中常用的有状态全局 PRNG 的玩具示例

def foo(): return bar() + baz()
def bar(): return rand(RNG, (3, 4))
def baz(): return rand(RNG, (3, 4))
def main():
  global RNG
  RNG = RandomState(0)
  return foo()

为了在此处实现可重现性,我们需要控制 bar() 和 baz() 的评估顺序,即使它们之间没有明确的数据依赖关系。这种源于可重现性的排序要求(#2)违反了可并行性(#5),并且不符合 JAX 或 XLA 的函数式语义(#6),在这些语义中,子表达式可以按任何顺序进行评估。即使我们不要求可重现性,从而允许任何评估顺序,共享状态的更新需求仍然会阻碍调用之间的并行化(#5)。此外,由于同一个 PRNG 状态需要在 Python 和任何编译后的代码中进行访问和维护,因此该模型很可能导致在实现编译不变性(#3)和扩展到多个副本(#6)方面遇到工程挑战。最后,表现力受限(#1),因为 foo() 无法在不影响自身(隐式)PRNG 状态的情况下调用 bar() 或 baz()。

该模型是否支持矢量化(#4)取决于一些附加细节。在 Numpy 中,PRNG 矢量化受到顺序等价保证的限制

In [1]: rng = np.random.RandomState(0)

In [2]: rng.randn(2)
Out[2]: array([1.76405235, 0.40015721])

In [3]: rng = np.random.RandomState(0)

In [4]: np.stack([rng.randn() for _ in range(2)])
Out[4]: array([1.76405235, 0.40015721])

为了允许在生成数组的原始 PRNG 函数调用(例如,带有 shape 参数的 rand())中进行矢量化(#4),我们放弃了此顺序等价保证。此矢量化可以由本节讨论的任何三种编程模型支持,尽管它促使我们以计数器为基础的 PRNG 来实现,如下一节所述。

有状态 PRNG 用户编程模型并不理想。这是一个函数式模型,但缺少我们称之为分割的关键成分

def foo(rng_1):
   y, rng_2 = baz(rng_1)
   z, rng_3 = bar(rng_2)
   return y + z, rng_3

def bar(x, rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def baz(x, rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def main():
  foo(RandomState(0))

此模型明确地将 PRNG 状态通过所有生成随机值的函数(原始或非原始)进行传递:即,每个随机函数都必须接受并返回状态。现在,在 foo() 中,baz() 的调用与 bar() 的调用之间存在明确的数据依赖关系,因此数据流(以及因此的排序)是明确的,并且符合 JAX 现有的语义(#7),与前一个模型不同。这种显式传递也可以使语义对编译边界不变(#3)。

显式传递对程序员来说不方便。但更糟糕的是,它并没有真正提高表现力(#1):foo() 仍然无法在保持自身 PRNG 状态的情况下调用 bar() 或 baz()。在不知道其调用者或它们调用的子例程的情况下,函数必须在所有地方防御性地传入和传出 rng 状态。此外,它也没有改善并行化(#5)或扩展到多个副本(#6)的前景,因为所有内容仍然是顺序的,即使排序在函数式编程的意义上是明确的。

简而言之,通过显式传递状态使代码函数化不足以实现我们的表现力(#1)和性能(#5, #6)目标。

前两种模型中的关键问题是排序过多。为了减少顺序依赖性,我们使用函数式 可分割 PRNG。分割是一种机制,可以将一个新的 PRNG 状态“分叉”成两个 PRNG 状态,同时保持通常的可取的 PRNG 属性(两个新流在计算上是可并行化的,并产生独立的随机值,即它们表现得像 多流)。

def foo(rng_1):
   rng_2, rng_3 = split(rng_1, 2)
   return bar(rng_2) + baz(rng_3)

def bar(x, rng):
  return rand(rng, (3, 4))

def baz(x, rng):
  return rand(rng, (3, 4))

def main():
  foo(RandomState(0))

一些要点

  1. bar() 和 baz() 的调用之间没有顺序依赖关系,它们可以按任何顺序评估而不会影响结果的值,这解决了剩余的性能目标(#5, #6),

  2. 函数不需要返回更新的 PRNG 版本,并且可以轻松地调用随机子例程而不会影响现有的 PRNG 状态,从而提高了从其他函数式模型获得的表现力(#1)。

示例并未显示,但作为选项(2)的结果,推进 PRNG 状态的唯一方法是调用 split()。也就是说,我们有两种方法可以实现(1),它们之间的区别在于它们是使用户程序承担显式调用 split() 的负担,如以上示例所示,还是承担显式传递的负担。我们更喜欢前者,即显式分割的版本,因为我们可以轻松地基于它来实现显式传递的版本。

设计#

我们可以使用基于计数器的 PRNG 设计,特别是 Threefry 哈希函数,如 Parallel random numbers: as easy as 1, 2, 3 中所述。我们使用计数器来实现高效的矢量化:对于给定的密钥,我们可以通过将哈希函数映射到一系列整数 [k + 1, …, k + sample_size] 来以矢量化方式生成数组值。我们使用密钥以及哈希函数来实现 可分割 PRNG:也就是说,分割是根据现有密钥生成两个新密钥的一种方法。

type Sample = Int256
type Key = Sample  -- important identification for splitting
type Count = Int32

hash :: Key -> Count -> Int256  -- output type equal to Key and Sample

split :: Key -> (Key, Key)
split key = (hash key 0, hash key 1)

draw_samples :: Key -> Int -> [Sample]
draw_samples key n = map (hash key) [1..n]

令人惊讶的是,抽取样本与分割非常相似!关键在于输出类型的差异(即使类型被标识):一种情况下,该值用于形成感兴趣的随机样本(例如,将随机位转换为表示随机正态分布的 Float),而在另一种情况下,该值将用作进一步哈希的密钥。

哈希函数参数(Key 和 Count 类型)之间的不对称性在于,后者是平凡且计算成本低的,可以通过任意数量来推进,因为我们只需要增加整数值,而前者仅通过哈希来推进。这就是我们使用计数器参数进行矢量化的原因。

更真实的示例用户程序#

当步骤需要 PRNG 时(可能是为了 dropout 或 VAE 训练),主机上的训练循环可能如下所示

rng = lax.rng.new_rng()
for i in xrange(num_steps):
  rng, rng_input = lax.rng.split(rng)
  params = compiled_update(rng_input, params, next(batches))

请注意,我们让用户承担了显式分割 rng 的负担,但 rng 不需要从代码中返回。

以下是我们如何使用此 PRNG 模型与 stax 神经网络构建库来实现 dropout

def Dropout(rate, mode='train'):
  def init_fun(input_shape):
    return input_shape, ()
  def apply_fun(rng, params, inputs):
    if mode == 'train':
      keep = lax.random.bernoulli(rng, rate, inputs.shape)
      return np.where(keep, inputs / rate, 0)
    else:
      return inputs
  return init_fun, apply_fun

此处的 rng 值只是用于哈希的密钥,而不是特殊对象。rng 参数传递给每个 apply_fun,因此需要在串行和并行组合器中通过分割来处理。

def serial(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    for rng, param, apply_fun in zip(rngs, params, apply_funs):
      inputs = apply_fun(rng, param, inputs)
    return inputs
  return init_fun, apply_fun

def parallel(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
  return init_fun, apply_fun

这里我们使用了一个简单的扩展版本的 split,它可以生成多个副本。

权衡与替代方案#

  1. 我们没有利用任何设备硬件 PRNG

    • 目前,我们对所有后端硬件 PRNG 的状态的控制不足。

    • 即使我们能做到,它也会依赖于后端,并且我们可能需要为确保确定性排序和因此的可重现性而引入随机调用之间的顺序依赖关系。

    • 我们不知道有任何工作负载会使软件 PRNG 成为瓶颈。

    • 我们可以考虑提供一个附加 API,允许用户访问硬件 PRNG,以换取他们放弃其他需求(如严格的可重现性)。

  2. 我们放弃了顺序等价保证,即一次创建随机数组会产生与一次创建一个展平的随机数组中的随机元素相同的值。

    • 此属性很可能与矢量化(高优先级)不兼容。

    • 我们不知道有任何用户或示例对此属性很重要。

    • 用户可以在此 API 之上编写一个层来提供此保证。

  3. 我们无法完全遵循 numpy.random API。