JAX PRNG 设计#

我们希望 PRNG 设计能够

  1. 具有 表达性 ,即使用方便,并且不会限制用户编写具有他们想要的精确行为的数值程序的能力,

  2. 以独立于后端的方式实现 可重现的 程序执行,

  3. 具有 不随 @jit 编译边界和设备后端而变化的 语义,

  4. 能够使用 SIMD 硬件 向量化以生成数组值

  5. 可并行化的 ,因为它不会在随机函数调用之间添加序列约束,否则这些调用将没有数据依赖性,

  6. 扩展到 多副本、多核和分布式计算

  7. 符合 JAX 和 XLA 语义 以及设计理念(最终由其他实际考虑因素驱动)。

作为这些的必然结果,我们认为设计应该是函数式的。另一个必然结果是,至少在当前的硬件约束下,我们将在软件中进行 PRNG。

TLDR JAX PRNG = Threefry 计数器 PRNG + 面向函数式数组的 拆分模型

目录#

三种编程模型和玩具示例程序#

这是一个 有状态全局 PRNG 的玩具示例,类似于 Numpy 程序中常用的 PRNG

def foo(): return bar() + baz()
def bar(): return rand(RNG, (3, 4))
def baz(): return rand(RNG, (3, 4))
def main():
  global RNG
  RNG = RandomState(0)
  return foo()

为了在此处实现可重现性,我们需要控制 bar() 和 baz() 的评估顺序,即使两者之间没有显式的数据依赖性。这种源于可重现性 (#2) 的排序要求违反了可并行性 (#5),并且不符合 JAX 或 XLA 的函数式语义 (#6),在函数式语义中,子表达式可以以任何顺序进行评估。即使我们不要求可重现性,因此允许任何评估顺序,跨调用的并行化 (#5) 仍然会因需要更新共享状态而变得困难。此外,由于需要在 Python 和任何编译代码中访问和维护相同的 PRNG 状态,因此该模型可能会导致工程挑战,以实现编译不变性 (#3) 和扩展到多个副本 (#6)。最后,表达性受到限制 (#1),因为 foo() 无法在不影响其自身(隐式)PRNG 状态的情况下调用 bar() 或 baz()。

该模型是否支持向量化 (#4) 取决于一些额外的细节。在 Numpy 中,PRNG 向量化受到 顺序等效保证 的限制

In [1]: rng = np.random.RandomState(0)

In [2]: rng.randn(2)
Out[2]: array([1.76405235, 0.40015721])

In [3]: rng = np.random.RandomState(0)

In [4]: np.stack([rng.randn() for _ in range(2)])
Out[4]: array([1.76405235, 0.40015721])

为了在生成数组的原始 PRNG 函数调用中允许向量化 (#4)(例如,对带有 shape 参数的 rand()),我们放弃了这种顺序等效保证。这种向量化可以由本节中讨论的三种编程模型中的任何一种支持,尽管它促使我们根据下一节中描述的基于计数器的 PRNG 来实现。

有状态 PRNG 用户编程模型没有前景。这是一个函数式模型的示例,但缺少我们称之为拆分的关键要素

def foo(rng_1):
   y, rng_2 = baz(rng_1)
   z, rng_3 = bar(rng_2)
   return y + z, rng_3

def bar(x, rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def baz(x, rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def main():
  foo(RandomState(0))

此模型显式地将 PRNG 状态传递到所有生成随机值的函数(原始函数或非原始函数):也就是说,每个随机函数都必须接受和返回状态。现在,在 foo() 中调用 baz() 和调用 bar() 之间存在显式的数据依赖性,因此数据流(以及随后的排序)变得显式,并且与 JAX 现有的语义 (#7) 相符,这与之前的模型不同。这种显式线程化还可以使语义不随编译边界而变化 (#3)。

显式线程化对于程序员来说很不方便。但更糟糕的是,它实际上并没有提高表达性 (#1):foo() 仍然无法在保持自身 PRNG 状态的情况下调用 bar() 或 baz()。如果不了解其调用者或它们调用的子例程,函数必须防御性地在所有地方传入和返回 rng 状态。此外,它也没有改善并行化 (#5) 或扩展到多个副本 (#6) 的前景,因为一切仍然是顺序的,即使排序在函数式编程意义上是显式的。

简而言之,通过显式线程化状态使代码具有函数式性不足以实现我们的表达性 (#1) 和性能 (#5, #6) 目标。

先前两个模型中的关键问题是排序过多。为了减少顺序依赖的数量,我们使用 函数式 可拆分 PRNG 。拆分是一种将新的 PRNG 状态“fork”成两个 PRNG 状态的机制,同时保持通常期望的 PRNG 属性(两个新的流在计算上是可并行化的,并且产生独立的随机值,即它们的行为类似于 多流 )。

def foo(rng_1):
   rng_2, rng_3 = split(rng_1, 2)
   return bar(rng_2) + baz(rng_3)

def bar(x, rng):
  return rand(rng, (3, 4))

def baz(x, rng):
  return rand(rng, (3, 4))

def main():
  foo(RandomState(0))

需要注意的一些要点

  1. 调用 bar() 和 baz() 之间没有顺序依赖性,它们可以以任何顺序进行评估,而不会影响结果的值,这解决了剩余的性能目标 (#5, #6),

  2. 函数不需要返回 PRNG 的更新版本,并且可以直接调用随机子例程,而不会影响现有的 PRNG 状态,从而提高了其他函数式模型的表达性 (#1)。

该示例没有显示这一点,但作为选择 (2) 的结果,推进 PRNG 状态的唯一方法是调用 split()。也就是说,我们有两种方法来实现 (1),它们的区别在于它们是否使用显式调用 split() 来加重用户程序的负担,如上面的示例所示,或者使用显式线程化来加重用户程序的负担。我们更喜欢前者,即具有显式拆分的版本,因为我们可以很容易地根据它来实现显式线程化版本。

设计#

我们可以使用 基于计数器的 PRNG 设计,特别是 Threefry 哈希函数,如 并行随机数:像 1、2、3 一样简单 中所述。我们使用计数器来实现高效的向量化:对于给定的键,我们可以通过将哈希函数映射到整数范围 [k + 1, …, k + sample_size] 来以向量化的方式生成一系列值。我们将键与哈希函数一起使用来实现 可拆分 PRNG :也就是说,拆分是一种从现有键生成两个新键的方法。

type Sample = Int256
type Key = Sample  -- important identification for splitting
type Count = Int32

hash :: Key -> Count -> Int256  -- output type equal to Key and Sample

split :: Key -> (Key, Key)
split key = (hash key 0, hash key 1)

draw_samples :: Key -> Int -> [Sample]
draw_samples key n = map (hash key) [1..n]

令人惊讶的是,抽取样本与拆分非常相似!关键在于输出类型的差异(即使类型被标识):在一种情况下,该值将用于形成感兴趣的随机样本(例如,将随机位转换为表示随机正态分布的浮点数),而在另一种情况下,该值将用作进一步哈希的键。

哈希函数参数(类型为 Key 和 Count)的不对称性在于,后者是微不足道的,并且在计算上以任意量推进的成本很低,因为我们只需要增加整数值,而前者仅通过哈希推进。这就是为什么我们使用 count 参数进行向量化的原因。

更真实的示例用户程序#

以下是主机上的训练循环可能的样子,当步骤需要 PRNG 时(可能是用于 dropout 或 VAE 训练)

rng = lax.rng.new_rng()
for i in xrange(num_steps):
  rng, rng_input = lax.rng.split(rng)
  params = compiled_update(rng_input, params, next(batches))

请注意,我们正在使用户承担显式拆分 rng 的负担,但 rng 不需要从代码中返回。

以下是我们如何将此 PRNG 模型与 stax 神经网络构建器库一起使用来实现 dropout

def Dropout(rate, mode='train'):
  def init_fun(input_shape):
    return input_shape, ()
  def apply_fun(rng, params, inputs):
    if mode == 'train':
      keep = lax.random.bernoulli(rng, rate, inputs.shape)
      return np.where(keep, inputs / rate, 0)
    else:
      return inputs
  return init_fun, apply_fun

此处的 rng 值只是用于哈希的键,而不是特殊对象。 rng 参数传递给每个 apply_fun,因此需要在串行和并行组合器中使用拆分来处理它

def serial(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    for rng, param, apply_fun in zip(rngs, params, apply_funs):
      inputs = apply_fun(rng, param, inputs)
    return inputs
  return init_fun, apply_fun

def parallel(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
  return init_fun, apply_fun

这里我们使用的是 split 的简单扩展版本,可以生成多个副本。

权衡和替代方案#

  1. 我们没有利用任何设备硬件 PRNG

    • 我们目前没有足够的控制权来控制所有后端的硬件 PRNG 状态。

    • 即使我们有,它也会依赖于后端,并且我们可能需要在随机调用之间引入顺序依赖性,以确保确定性排序,从而确保可重现性。

    • 我们不知道任何软件 PRNG 会成为瓶颈的工作负载。

    • 我们可以考虑提供额外的 API,允许用户访问硬件 PRNG,这些用户希望放弃其他期望的特性(例如严格的可重现性)。

  2. 我们放弃了顺序等效保证,即在一个调用中创建随机数组会产生与一次创建一个扁平数组的随机元素相同的值。

    • 此属性可能与向量化(高优先级)不兼容。

    • 我们不知道任何用户或示例认为此属性很重要。

    • 用户可以在此 API 之上编写一个层来提供此保证。

  3. 我们无法完全遵循 numpy.random API。