伪随机数#

如果所有因糟糕的 rand() 函数而导致结果存疑的科学论文从图书馆书架上消失,每个书架上都会出现一个拳头大小的空缺。 - Numerical Recipes

本节我们将重点介绍 jax.random 和伪随机数生成(PRNG);即,通过算法生成一系列数字,这些数字的属性近似于从适当分布中采样的随机数序列的属性。

PRNG 生成的序列并非真正的随机,因为它们实际上由其初始值(通常称为 seed)确定,并且随机采样的每一步都是从一个样本传递到下一个样本的某个 state 的确定性函数。

伪随机数生成是任何机器学习或科学计算框架的重要组成部分。通常,JAX 力求与 NumPy 兼容,但伪随机数生成是一个显著的例外。

为了更好地理解 JAX 和 NumPy 在随机数生成方法上的区别,我们将在本节中讨论这两种方法。

NumPy 中的随机数#

NumPy 的 numpy.random 模块原生支持伪随机数生成。在 NumPy 中,伪随机数生成基于一个全局 state,该状态可以使用 numpy.random.seed() 设置为确定性的初始条件。

import numpy as np
np.random.seed(0)

重复调用 NumPy 的有状态伪随机数生成器(PRNG)会改变全局状态并生成一串伪随机数。

print(np.random.random())
print(np.random.random())
print(np.random.random())
0.5488135039273248
0.7151893663724195
0.6027633760716439

在底层,NumPy 使用 Mersenne Twister PRNG 来驱动其伪随机函数。该 PRNG 的周期为 \(2^{19937}-1\),在任何时候都可以用 624 个 32 位无符号整数和一个表示已使用了多少“熵”的位置来描述。

您可以使用以下命令检查状态内容。

def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')

print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

state 会被每次对随机函数的调用更新

np.random.seed(0)
print_truncated_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

NumPy 允许您通过单次函数调用采样单个数字或整个数字向量。例如,您可以从均匀分布中采样一个包含 3 个标量的向量,方法是:

np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135  0.71518937 0.60276338]

NumPy 提供序列等价保证,这意味着连续单独采样 N 个数字或采样一个包含 N 个数字的向量会产生相同的伪随机序列。

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]

JAX 中的随机数#

JAX 的随机数生成与 NumPy 的在重要方面有所不同,因为 NumPy 的 PRNG 设计使其难以同时保证许多期望的特性。具体而言,在 JAX 中,我们希望 PRNG 生成是:

  1. 可复现的,

  2. 可并行化的,

  3. 可向量化的。

我们将在下文讨论原因。首先,我们将重点关注基于全局状态的 PRNG 设计的影响。考虑以下代码:

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())
1.9791922366721637

函数 foo 将从均匀分布中采样的两个标量求和。

仅当我们假设 bar()baz() 的执行顺序是可预测的,此代码的输出才能满足要求 #1。这在 NumPy 中不是问题,NumPy 总是按照 Python 解释器定义的顺序评估代码。然而,在 JAX 中,这更具问题:为了高效执行,我们希望 JIT 编译器能够自由地重新排序、省略和融合我们定义的函数中的各种操作。此外,在多设备环境中执行时,每个进程同步全局状态的需要会阻碍执行效率。

显式随机状态#

为避免这些问题,JAX 避免隐式全局随机状态,而是通过随机 key 显式跟踪状态。

from jax import random

key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]

注意

本节使用 jax.random.key() 生成的新式类型化 PRNG 密钥,而不是 jax.random.PRNGKey() 生成的旧式原始 PRNG 密钥。有关详细信息,请参阅 JEP 9263:类型化密钥与可插拔 RNGs

密钥是一个具有特殊 `dtype` 的数组,对应于所使用的特定 PRNG 实现;在默认实现中,每个密钥由一对 uint32 值支持。

该密钥实际上是 NumPy 隐藏状态对象的替代品,但我们将其显式传递给 jax.random() 函数。重要的是,随机函数会消耗密钥,但不会修改它:将相同的密钥对象传递给随机函数将始终生成相同的样本。

print(random.normal(key))
print(random.normal(key))
-0.028304616
-0.028304616

重复使用相同的密钥,即使使用不同的 random API,也可能导致相关的输出,这通常是不希望的。

经验法则是:绝不要重复使用密钥(除非你想要相同的输出)。重复使用相同的状态会导致悲伤单调,剥夺最终用户的活力四射的混沌

JAX 使用现代的 Threefry 基于计数器的可拆分 PRNG。也就是说,其设计允许我们将 PRNG 状态分裂成新的 PRNG,用于并行随机生成。为了生成不同且独立的样本,您必须在将密钥传递给随机函数之前显式地 split() 密钥:

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: 0.6057640314102173
draw 1: -0.21089035272598267
draw 2: -0.3948981463909149

(这里不需要调用 del,但我们这样做是为了强调密钥一旦使用后不应重复使用。)

jax.random.split() 是一个确定性函数,它将一个 key 转换为几个独立的(在伪随机性意义上)密钥。我们将其中一个输出保留为 new_key,并且可以安全地将唯一的额外密钥(称为 subkey)用作随机函数的输入,然后永远丢弃它。如果您想从正态分布中获得另一个样本,您将再次分裂 key,依此类推:关键点是您绝不要两次使用同一个密钥。

我们将 split(key) 的输出中的哪一部分称为 key,哪一部分称为 subkey 并不重要。它们都是具有同等地位的独立密钥。key/subkey 命名约定是一种典型的使用模式,有助于跟踪密钥的消耗方式:subkey 旨在立即被随机函数使用,而 key 则被保留用于后续生成更多随机性。

通常,上面的示例可以简洁地写成:

key, subkey = random.split(key)

这会自动丢弃旧密钥。值得注意的是,split() 可以创建你需要的任意数量的密钥,而不仅仅是 2 个。

key, *forty_two_subkeys = random.split(key, num=43)

缺乏序列等价性#

NumPy 和 JAX 随机模块的另一个区别与上述提到的序列等价保证有关。

与 NumPy 中一样,JAX 的随机模块也允许采样数字向量。但是,JAX 不提供序列等价保证,因为这样做会干扰 SIMD 硬件上的向量化(上述要求 #3)。

在下面的示例中,使用三个子密钥单独从正态分布中采样 3 个值与使用单个密钥并指定 shape=(3,) 会得到不同的结果:

key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [0.07592554 0.60576403 0.4323065 ]
all at once:  [-0.02830462  0.46713185  0.29570296]

缺乏序列等价性使我们能够更高效地编写代码;例如,我们可以使用 jax.vmap() 以向量化方式计算相同的结果,而不是通过顺序循环生成上述 sequence

import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [0.07592554 0.60576403 0.4323065 ]

后续步骤#

有关 JAX 随机数的更多信息,请参阅 jax.random 模块的文档。如果您对 JAX 随机数生成器设计的细节感兴趣,请参阅 JAX PRNG 设计