伪随机数#
如果所有因糟糕的
rand()
函数而导致结果存疑的科学论文从图书馆书架上消失,每个书架上都会出现一个拳头大小的空缺。 - Numerical Recipes
本节我们将重点介绍 jax.random
和伪随机数生成(PRNG);即,通过算法生成一系列数字,这些数字的属性近似于从适当分布中采样的随机数序列的属性。
PRNG 生成的序列并非真正的随机,因为它们实际上由其初始值(通常称为 seed
)确定,并且随机采样的每一步都是从一个样本传递到下一个样本的某个 state
的确定性函数。
伪随机数生成是任何机器学习或科学计算框架的重要组成部分。通常,JAX 力求与 NumPy 兼容,但伪随机数生成是一个显著的例外。
为了更好地理解 JAX 和 NumPy 在随机数生成方法上的区别,我们将在本节中讨论这两种方法。
NumPy 中的随机数#
NumPy 的 numpy.random
模块原生支持伪随机数生成。在 NumPy 中,伪随机数生成基于一个全局 state
,该状态可以使用 numpy.random.seed()
设置为确定性的初始条件。
import numpy as np
np.random.seed(0)
重复调用 NumPy 的有状态伪随机数生成器(PRNG)会改变全局状态并生成一串伪随机数。
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.5488135039273248
0.7151893663724195
0.6027633760716439
在底层,NumPy 使用 Mersenne Twister PRNG 来驱动其伪随机函数。该 PRNG 的周期为 \(2^{19937}-1\),在任何时候都可以用 624 个 32 位无符号整数和一个表示已使用了多少“熵”的位置来描述。
您可以使用以下命令检查状态内容。
def print_truncated_random_state():
"""To avoid spamming the outputs, print only part of the state."""
full_random_state = np.random.get_state()
print(str(full_random_state)[:460], '...')
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
state
会被每次对随机函数的调用更新
np.random.seed(0)
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
NumPy 允许您通过单次函数调用采样单个数字或整个数字向量。例如,您可以从均匀分布中采样一个包含 3 个标量的向量,方法是:
np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135 0.71518937 0.60276338]
NumPy 提供序列等价保证,这意味着连续单独采样 N 个数字或采样一个包含 N 个数字的向量会产生相同的伪随机序列。
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135 0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
JAX 中的随机数#
JAX 的随机数生成与 NumPy 的在重要方面有所不同,因为 NumPy 的 PRNG 设计使其难以同时保证许多期望的特性。具体而言,在 JAX 中,我们希望 PRNG 生成是:
可复现的,
可并行化的,
可向量化的。
我们将在下文讨论原因。首先,我们将重点关注基于全局状态的 PRNG 设计的影响。考虑以下代码:
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
1.9791922366721637
函数 foo
将从均匀分布中采样的两个标量求和。
仅当我们假设 bar()
和 baz()
的执行顺序是可预测的,此代码的输出才能满足要求 #1。这在 NumPy 中不是问题,NumPy 总是按照 Python 解释器定义的顺序评估代码。然而,在 JAX 中,这更具问题:为了高效执行,我们希望 JIT 编译器能够自由地重新排序、省略和融合我们定义的函数中的各种操作。此外,在多设备环境中执行时,每个进程同步全局状态的需要会阻碍执行效率。
显式随机状态#
为避免这些问题,JAX 避免隐式全局随机状态,而是通过随机 key
显式跟踪状态。
from jax import random
key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]
注意
本节使用 jax.random.key()
生成的新式类型化 PRNG 密钥,而不是 jax.random.PRNGKey()
生成的旧式原始 PRNG 密钥。有关详细信息,请参阅 JEP 9263:类型化密钥与可插拔 RNGs。
密钥是一个具有特殊 `dtype` 的数组,对应于所使用的特定 PRNG 实现;在默认实现中,每个密钥由一对 uint32
值支持。
该密钥实际上是 NumPy 隐藏状态对象的替代品,但我们将其显式传递给 jax.random()
函数。重要的是,随机函数会消耗密钥,但不会修改它:将相同的密钥对象传递给随机函数将始终生成相同的样本。
print(random.normal(key))
print(random.normal(key))
-0.028304616
-0.028304616
重复使用相同的密钥,即使使用不同的 random
API,也可能导致相关的输出,这通常是不希望的。
经验法则是:绝不要重复使用密钥(除非你想要相同的输出)。重复使用相同的状态会导致悲伤和单调,剥夺最终用户的活力四射的混沌。
JAX 使用现代的 Threefry 基于计数器的可拆分 PRNG。也就是说,其设计允许我们将 PRNG 状态分裂成新的 PRNG,用于并行随机生成。为了生成不同且独立的样本,您必须在将密钥传递给随机函数之前显式地 split()
密钥:
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: 0.6057640314102173
draw 1: -0.21089035272598267
draw 2: -0.3948981463909149
(这里不需要调用 del
,但我们这样做是为了强调密钥一旦使用后不应重复使用。)
jax.random.split()
是一个确定性函数,它将一个 key
转换为几个独立的(在伪随机性意义上)密钥。我们将其中一个输出保留为 new_key
,并且可以安全地将唯一的额外密钥(称为 subkey
)用作随机函数的输入,然后永远丢弃它。如果您想从正态分布中获得另一个样本,您将再次分裂 key
,依此类推:关键点是您绝不要两次使用同一个密钥。
我们将 split(key)
的输出中的哪一部分称为 key
,哪一部分称为 subkey
并不重要。它们都是具有同等地位的独立密钥。key/subkey 命名约定是一种典型的使用模式,有助于跟踪密钥的消耗方式:subkey 旨在立即被随机函数使用,而 key 则被保留用于后续生成更多随机性。
通常,上面的示例可以简洁地写成:
key, subkey = random.split(key)
这会自动丢弃旧密钥。值得注意的是,split()
可以创建你需要的任意数量的密钥,而不仅仅是 2 个。
key, *forty_two_subkeys = random.split(key, num=43)
缺乏序列等价性#
NumPy 和 JAX 随机模块的另一个区别与上述提到的序列等价保证有关。
与 NumPy 中一样,JAX 的随机模块也允许采样数字向量。但是,JAX 不提供序列等价保证,因为这样做会干扰 SIMD 硬件上的向量化(上述要求 #3)。
在下面的示例中,使用三个子密钥单独从正态分布中采样 3 个值与使用单个密钥并指定 shape=(3,)
会得到不同的结果:
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [0.07592554 0.60576403 0.4323065 ]
all at once: [-0.02830462 0.46713185 0.29570296]
缺乏序列等价性使我们能够更高效地编写代码;例如,我们可以使用 jax.vmap()
以向量化方式计算相同的结果,而不是通过顺序循环生成上述 sequence
:
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [0.07592554 0.60576403 0.4323065 ]
后续步骤#
有关 JAX 随机数的更多信息,请参阅 jax.random
模块的文档。如果您对 JAX 随机数生成器设计的细节感兴趣,请参阅 JAX PRNG 设计。