伪随机数#
如果所有因糟糕的 `rand()` 导致结果存疑的科学论文都从图书馆书架上消失,那么每个书架上都会留下一个大约有拳头大小的空隙。 - Numerical Recipes
在本节中,我们将重点介绍 `jax.random` 模块以及伪随机数生成 (PRNG);也就是说,通过算法生成数字序列的过程,其属性近似于从适当的分布中采样的随机数序列的属性。
PRNG 生成的序列并非真正随机,因为它们实际上由它们的初始值(通常称为 `seed`)决定,并且每次随机采样步骤都是从一个样本传递到下一个样本的某个 `state` 的确定性函数。
伪随机数生成是任何机器学习或科学计算框架的基本组成部分。通常,JAX 努力与 NumPy 兼容,但伪随机数生成是一个显著的例外。
为了更好地理解 JAX 和 NumPy 在随机数生成方面的不同方法,我们将在本节中讨论这两种方法。
NumPy 中的随机数#
NumPy 通过 `numpy.random` 模块原生支持伪随机数生成。在 NumPy 中,伪随机数生成基于全局 `state`,可以使用 `numpy.random.seed()` 将其设置为确定性初始条件。
import numpy as np
np.random.seed(0)
重复调用 NumPy 的有状态伪随机数生成器 (PRNG) 会改变全局状态,并产生一个伪随机数流。
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.5488135039273248
0.7151893663724195
0.6027633760716439
在底层,NumPy 使用 Mersenne Twister PRNG 来驱动其伪随机函数。该 PRNG 的周期为 \(2^{19937}-1\),在任何时刻都可以用 624 个 32 位无符号整数和一个指示“熵”已使用多少的指针来描述。
您可以使用以下命令检查状态的内容。
def print_truncated_random_state():
"""To avoid spamming the outputs, print only part of the state."""
full_random_state = np.random.get_state()
print(str(full_random_state)[:460], '...')
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
每次调用随机函数时,`state` 都会更新。
np.random.seed(0)
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
NumPy 允许您在单个函数调用中对单个数字或整个数字向量进行采样。例如,您可以通过以下方式从均匀分布中采样一个包含 3 个标量的向量:
np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135 0.71518937 0.60276338]
NumPy 提供了一个 *序列等价保证*,这意味着单独连续采样 N 个数字与采样 N 个数字的向量会产生相同的伪随机序列。
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135 0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
JAX 中的随机数#
JAX 的随机数生成与 NumPy 的随机数生成在重要方面有所不同,因为 NumPy 的 PRNG 设计使得同时保证一系列期望的属性变得困难。具体来说,在 JAX 中,我们希望 PRNG 生成是
可复现的,
可并行化的,
可向量化的。
我们将在下面讨论原因。首先,我们将重点关注基于全局状态的 PRNG 设计的含义。考虑以下代码:
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
1.9791922366721637
`foo` 函数对从均匀分布中采样的两个标量求和。
此代码的输出只有在我们假设 `bar()` 和 `baz()` 的执行顺序可预测的情况下才能满足要求 #1。这在 NumPy 中不是问题,NumPy 总是按照 Python 解释器定义的顺序评估代码。然而,在 JAX 中,这是一个更棘手的问题:为了高效执行,我们希望 JIT 编译器能够自由地重新排序、省略和融合我们定义的函数中的各种操作。此外,在多设备环境中执行时,需要每个进程同步全局状态会阻碍执行效率。
显式的随机状态#
为了避免这些问题,JAX 避免使用隐式的全局随机状态,而是通过一个随机 `key` 来显式跟踪状态。
from jax import random
key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]
注意
本节使用由 `jax.random.key()` 生成的新式类型化 PRNG 密钥,而不是由 `jax.random.PRNGKey()` 生成的旧式原始 PRNG 密钥。有关详细信息,请参阅 JEP 9263:类型化密钥和可插拔 RNG。
密钥是一个具有特殊 dtype 的数组,对应于正在使用的特定 PRNG 实现;在默认实现中,每个密钥都由一对 `uint32` 值支持。
密钥实际上是 NumPy 隐藏状态对象的替代品,但我们将其显式传递给 `jax.random()` 函数。重要的是,随机函数会消耗密钥,但不会修改它:将相同的密钥对象馈送到随机函数总是会生成相同的样本。
print(random.normal(key))
print(random.normal(key))
-0.028304616
-0.028304616
即使使用不同的 `random` API 重用相同的密钥,也可能导致输出相关,这通常是不可取的。
经验法则:永远不要重复使用密钥(除非您想要相同的输出)。重复使用相同的状态会导致*悲伤*和*单调*,剥夺最终用户*充满生命力的混沌*。
JAX 使用现代 Threefry 基于计数器的 PRNG,它是可分裂的。也就是说,其设计允许我们将 PRNG 状态分叉成新的 PRNG,用于并行随机生成。为了生成不同且独立的样本,您必须在将其传递给随机函数之前显式地 `split()` 密钥。
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: 0.6057640314102173
draw 1: -0.21089035272598267
draw 2: -0.3948981463909149
(此处调用 `del` 不是必需的,但我们这样做是为了强调一旦消耗了密钥,就不应该重复使用它。)
`jax.random.split()` 是一个确定性函数,它将一个 `key` 转换为几个独立(在伪随机意义上)的密钥。我们将其中一个输出保留为 `new_key`,并可以安全地将唯一的额外密钥(称为 `subkey`)用作随机函数的输入,然后永远丢弃它。如果您想从正态分布中获取另一个样本,您将再次分割 `key`,依此类推:关键点在于您永远不会使用同一个密钥两次。
将 `split(key)` 的哪个部分称为 `key`,哪个称为 `subkey` 并不重要。它们都是具有相等地位的独立密钥。key/subkey 命名约定是一个典型的用法模式,有助于跟踪密钥的消耗方式:subkey 用于立即由随机函数消耗,而 key 则保留用于稍后生成更多随机性。
通常,上面的示例会简洁地写成:
key, subkey = random.split(key)
这会自动丢弃旧的密钥。值得注意的是,`split()` 可以创建任意数量的密钥,而不仅仅是 2 个。
key, *forty_two_subkeys = random.split(key, num=43)
缺乏序列等价性#
NumPy 和 JAX 的随机模块之间的另一个区别与上面提到的序列等价性保证有关。
与 NumPy 一样,JAX 的随机模块也允许对数字向量进行采样。但是,JAX 不提供序列等价保证,因为这样做会干扰 SIMD 硬件上的向量化(上面要求 #3)。
在下面的示例中,使用三个子密钥单独从正态分布中采样 3 个值,与使用单个密钥并指定 `shape=(3,)` 的结果不同。
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [0.07592554 0.60576403 0.4323065 ]
all at once: [-0.02830462 0.46713185 0.29570296]
缺乏序列等价性让我们能够更高效地编写代码;例如,我们不必通过顺序循环生成上面的 `sequence`,而是可以使用 `jax.vmap()` 以向量化的方式计算相同的结果。
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [0.07592554 0.60576403 0.4323065 ]
后续步骤#
有关 JAX 随机数的更多信息,请参阅 `jax.random` 模块的文档。如果您对 JAX 随机数生成器的设计细节感兴趣,请参阅 JAX PRNG 设计。