伪随机数#
如果所有由于错误的
rand()
而导致结果存疑的科学论文从图书馆书架上消失,那么每个书架上都会有一个大约拳头大小的空隙。 - 数值食谱
在本节中,我们重点介绍 jax.random
和伪随机数生成 (PRNG);也就是说,算法生成数字序列的过程,这些序列的属性近似于从适当分布中采样的随机数序列的属性。
PRNG 生成的序列不是真正的随机序列,因为它们实际上由其初始值决定,初始值通常称为 seed
,并且随机采样的每个步骤都是某个 state
的确定性函数,该状态从一个样本传递到下一个样本。
伪随机数生成是任何机器学习或科学计算框架的重要组成部分。一般来说,JAX 努力与 NumPy 兼容,但伪随机数生成是一个明显的例外。
为了更好地理解 JAX 和 NumPy 在随机数生成方面采取的方法之间的差异,我们将在本节中讨论这两种方法。
NumPy 中的随机数#
NumPy 的 numpy.random
模块原生支持伪随机数生成。在 NumPy 中,伪随机数生成基于全局 state
,可以使用 numpy.random.seed()
将全局状态设置为确定性初始条件。
import numpy as np
np.random.seed(0)
重复调用 NumPy 的有状态伪随机数生成器 (PRNG) 会改变全局状态并产生伪随机数流
print(np.random.random())
print(np.random.random())
print(np.random.random())
0.5488135039273248
0.7151893663724195
0.6027633760716439
在底层,NumPy 使用 梅森旋转 PRNG 为其伪随机函数提供动力。PRNG 的周期为 \(2^{19937}-1\),并且在任何时候都可以用 624 个 32 位无符号整数和一个位置来描述,该位置指示已用完多少“熵”。
您可以使用以下命令检查状态的内容。
def print_truncated_random_state():
"""To avoid spamming the outputs, print only part of the state."""
full_random_state = np.random.get_state()
print(str(full_random_state)[:460], '...')
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
state
会在每次调用随机函数时更新
np.random.seed(0)
print_truncated_random_state()
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
NumPy 允许您在单个函数调用中采样单个数字或数字的整个向量。例如,您可以从均匀分布中采样一个包含 3 个标量的向量,方法是
np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135 0.71518937 0.60276338]
NumPy 提供顺序等价性保证,这意味着连续单独采样 N 个数字或采样 N 个数字的向量会导致相同的伪随机序列
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135 0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
JAX 中的随机数#
JAX 的随机数生成在重要方面与 NumPy 的不同,因为 NumPy 的 PRNG 设计使其难以同时保证许多理想的属性。具体来说,在 JAX 中,我们希望 PRNG 生成具有以下特性:
可重现性,
可并行化,
可向量化。
我们将在下面讨论原因。首先,我们将重点关注基于全局状态的 PRNG 设计的含义。考虑以下代码
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
1.9791922366721637
函数 foo
对从均匀分布中采样的两个标量求和。
只有当我们假设 bar()
和 baz()
的执行顺序可预测时,此代码的输出才能满足要求 #1。这在 NumPy 中不是问题,NumPy 始终按照 Python 解释器定义的顺序评估代码。然而,在 JAX 中,这更成问题:为了高效执行,我们希望 JIT 编译器可以自由地重新排序、省略和融合我们定义的函数中的各种操作。此外,当在多设备环境中执行时,每个进程都需要同步全局状态,这将阻碍执行效率。
显式随机状态#
为了避免这些问题,JAX 避免了隐式全局随机状态,而是通过随机 key
显式跟踪状态
from jax import random
key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]
注意
本节使用由 jax.random.key()
生成的新式类型化 PRNG 键,而不是由 jax.random.PRNGKey()
生成的旧式原始 PRNG 键。有关详细信息,请参阅 JEP 9263:类型化键 & 可插拔 RNG。
键是一个数组,具有与正在使用的特定 PRNG 实现相对应的特殊 dtype;在默认实现中,每个键都由一对 uint32
值支持。
键实际上是 NumPy 的隐藏状态对象的替代品,但我们将其显式传递给 jax.random()
函数。重要的是,随机函数会消耗键,但不会修改它:将相同的键对象馈送到随机函数始终会导致生成相同的样本。
print(random.normal(key))
print(random.normal(key))
-0.028304616
-0.028304616
即使使用不同的 random
API 重复使用相同的键也可能导致相关的输出,这通常是不希望的。
经验法则是:永远不要重复使用键(除非您想要相同的输出)。重复使用相同的状态会导致悲伤和单调,剥夺最终用户的赋予生命的混乱。
JAX 使用现代 Threefry 基于计数器的 PRNG,它是可拆分的。也就是说,它的设计允许我们将 PRNG 状态分支到新的 PRNG 中,以用于并行随机生成。为了生成不同且独立的样本,您必须在将键传递给随机函数之前显式 split()
键
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: 0.6057640314102173
draw 1: -0.21089035272598267
draw 2: -0.3948981463909149
(此处不需要调用 del
,但我们这样做是为了强调一旦消耗了键就不应重复使用。)
jax.random.split()
是一个确定性函数,它将一个 key
转换为多个独立的(在伪随机性意义上)键。我们将其中一个输出保留为 new_key
,并且可以安全地使用唯一的额外键(称为 subkey
)作为随机函数的输入,然后永远丢弃它。如果您想从正态分布中获取另一个样本,您将再次拆分 key
,依此类推:关键是您永远不要使用相同的键两次。
调用 split(key)
的哪个输出为 key
,哪个为 subkey
并不重要。它们都是具有同等地位的独立键。键/子键命名约定是一种典型的使用模式,有助于跟踪键的消耗方式:子键注定要被随机函数立即消耗,而键则保留以稍后生成更多随机性。
通常,上面的示例会简洁地写成
key, subkey = random.split(key)
它会自动丢弃旧键。值得注意的是,split()
可以创建您需要的任意数量的键,而不仅仅是 2 个
key, *forty_two_subkeys = random.split(key, num=43)
缺乏顺序等价性#
NumPy 和 JAX 的随机模块之间的另一个区别与上面提到的顺序等价性保证有关。
与 NumPy 中一样,JAX 的随机模块也允许采样数字向量。但是,JAX 不提供顺序等价性保证,因为这样做会干扰 SIMD 硬件上的向量化(上述要求 #3)。
在下面的示例中,使用三个子键单独采样正态分布的 3 个值,得到的结果与使用单个键并指定 shape=(3,)
的结果不同
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [0.07592554 0.60576403 0.4323065 ]
all at once: [-0.02830462 0.46713185 0.29570296]
缺乏顺序等价性使我们能够更有效地编写代码;例如,我们可以使用 jax.vmap()
以向量化的方式计算相同的结果,而不是通过顺序循环生成上面的 sequence
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [0.07592554 0.60576403 0.4323065 ]
下一步#
有关 JAX 随机数的更多信息,请参阅 jax.random
模块的文档。如果您对 JAX 随机数生成器设计的详细信息感兴趣,请参阅 JAX PRNG 设计。