伪随机数#

如果所有因为使用了糟糕的 rand() 函数而导致结果存疑的科学论文都从图书馆书架上消失,那么每个书架上都会留下一个拳头大小的空隙。 - 《数值配方》(Numerical Recipes)

在本节中,我们将重点介绍 jax.random 和伪随机数生成 (PRNG);也就是说,通过算法生成数字序列的过程,这些序列的属性近似于从适当分布中采样得到的随机数序列的属性。

PRNG 生成的序列并非真正的随机,因为它们实际上是由初始值(通常称为 seed,即种子)决定的。每次随机采样步骤都是某个 state(状态)的确定性函数,该状态会从一个样本传递到下一个样本。

伪随机数生成是任何机器学习或科学计算框架的重要组成部分。通常,JAX 力求与 NumPy 兼容,但伪随机数生成是一个显著的例外。

为了更好地理解 JAX 和 NumPy 在随机数生成方面所采取方法的区别,我们将在本节中讨论这两种方法。

NumPy 中的随机数#

NumPy 通过 numpy.random 模块原生支持伪随机数生成。在 NumPy 中,伪随机数生成基于一个全局 state(状态),可以使用 numpy.random.seed() 将其设置为确定性的初始条件。

import numpy as np
np.random.seed(0)

重复调用 NumPy 的有状态伪随机数生成器 (PRNG) 会改变全局状态,并产生一系列伪随机数。

print(np.random.random())
print(np.random.random())
print(np.random.random())
0.5488135039273248
0.7151893663724195
0.6027633760716439

在底层,NumPy 使用 梅森旋转算法 (Mersenne Twister) PRNG 来驱动其伪随机函数。该 PRNG 的周期为 \(2^{19937}-1\),在任何时刻都可以由 624 个 32 位无符号整数以及一个表示已消耗了多少“熵”的位置来描述。

您可以使用以下命令检查该状态的内容。

def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')

print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

state 会随着每次调用随机函数而更新。

np.random.seed(0)
print_truncated_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

NumPy 允许您在单次函数调用中对单个数字或整个数字向量进行采样。例如,您可以通过以下方式从均匀分布中采样 3 个标量的向量:

np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135  0.71518937 0.60276338]

NumPy 提供顺序等价保证 (sequential equivalent guarantee),这意味着连续采样 N 个数字或采样一个包含 N 个数字的向量,会产生相同的伪随机序列。

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]

JAX 中的随机数#

JAX 的随机数生成在很多重要方面与 NumPy 不同,因为 NumPy 的 PRNG 设计很难同时保证多种理想属性。具体而言,在 JAX 中,我们希望 PRNG 生成具有以下特性:

  1. 可复现性 (reproducible),

  2. 可并行化 (parallelizable),

  3. 可向量化 (vectorisable)。

我们将在下文中讨论原因。首先,我们将重点讨论基于全局状态的 PRNG 设计所带来的影响。考虑以下代码:

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())
1.9791922366721637

函数 foo 对从均匀分布中采样的两个标量进行求和。

如果假设 bar()baz() 的执行顺序是可预测的,那么此代码的输出才能满足需求 #1。在 NumPy 中这不是问题,因为 NumPy 总是按照 Python 解释器定义的顺序执行代码。然而,在 JAX 中,这会带来更多问题:为了高效执行,我们希望 JIT 编译器能够自由地重排、省略和融合我们定义的函数中的各种操作。此外,在多设备环境中执行时,如果需要每个进程同步全局状态,执行效率将会受到影响。

显式随机状态#

为了避免这些问题,JAX 避免使用隐式的全局随机状态,而是通过随机 key(键)显式跟踪状态。

from jax import random

key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]

注意

本节使用的是由 jax.random.key() 生成的新式类型化 PRNG 键,而不是由 jax.random.PRNGKey() 生成的旧式原始 PRNG 键。有关详细信息,请参阅 JEP 9263: 类型化键与可插拔 RNG

键是一个具有特殊数据类型的数组,对应于所使用的特定 PRNG 实现;在默认实现中,每个键由一对 uint32 值支持。

该键实际上充当了 NumPy 隐藏状态对象的替代品,但我们将其显式传递给 jax.random() 函数。重要的是,随机函数会消耗键,但不会修改它:将相同的键对象输入到随机函数中,将始终生成相同的样本。

print(random.normal(key))
print(random.normal(key))
-0.028304616
-0.028304616

重复使用同一个键,即使使用不同的 random API,也可能导致相关输出,这通常是不希望看到的。

经验法则是:永远不要重用键(除非您想要完全相同的输出)。重用同一个状态会导致悲伤单调,使最终用户失去赋予生命的混沌

JAX 使用一种现代的 Threefry 基于计数器的 PRNG,它是可拆分的。也就是说,其设计允许我们将 PRNG 状态分叉为新的 PRNG,以用于并行随机生成。为了生成不同且独立的样本,您必须在将键传递给随机函数之前,显式地 split()(拆分)该键。

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: 0.6057640314102173
draw 1: -0.21089035272598267
draw 2: -0.3948981463909149

(此处不需要调用 del,但我们这样做是为了强调键一旦被消耗就不应再被重用。)

jax.random.split() 是一个确定性函数,它将一个 key 转换为多个相互独立(在伪随机意义上)的键。我们保留其中一个输出作为 new_key,并将唯一的额外键(称为 subkey,即子键)作为输入安全地传给随机函数,然后将其丢弃。如果您想从正态分布中获取另一个样本,您可以再次拆分 key,依此类推:关键在于您从不重复使用同一个键。

我们将 split(key) 输出的哪一部分称为 key,哪一部分称为 subkey 并不重要。它们都是地位平等的独立键。“键/子键”命名约定是一种典型的使用模式,有助于跟踪键是如何被消耗的:子键注定会被随机函数立即消耗,而键则被保留下来用于稍后生成更多的随机性。

通常,上述示例可以简写为:

key, subkey = random.split(key)

它会自动丢弃旧键。值得注意的是,split() 可以创建您所需的任意多个键,而不仅仅是 2 个。

key, *forty_two_subkeys = random.split(key, num=43)

缺乏顺序等价性#

NumPy 和 JAX 随机模块之间的另一个区别涉及上述提到的顺序等价保证。

与 NumPy 一样,JAX 的随机模块也允许对数字向量进行采样。然而,JAX 不提供顺序等价保证,因为这样做会干扰 SIMD 硬件上的向量化(上述需求 #3)。

在下面的示例中,使用三个子键分别从正态分布中采样 3 个值,其结果与使用单个键并指定 shape=(3,) 的结果不同。

key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [0.07592554 0.60576403 0.4323065 ]
all at once:  [-0.02830462  0.46713185  0.29570296]

缺乏顺序等价性使我们能够更自由地编写高效代码;例如,我们无需通过顺序循环来生成上述的 sequence,而是可以使用 jax.vmap() 以向量化方式计算出相同的结果。

import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [0.07592554 0.60576403 0.4323065 ]

后续步骤#

有关 JAX 随机数的更多信息,请参阅 jax.random 模块的文档。如果您对 JAX 随机数生成器的设计细节感兴趣,请参阅 JAX PRNG 设计