jax.random
模块#
用于伪随机数生成的实用工具。
jax.random
包提供了一系列用于确定性生成伪随机数序列的例程。
基本用法#
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches))
PRNG 密钥#
与 NumPy 和 SciPy 用户可能习惯的有状态伪随机数生成器 (PRNG) 不同,JAX 随机函数都要求显式 PRNG 状态作为第一个参数传递。随机状态由我们称之为密钥的特殊数组元素类型描述,通常由 jax.random.key()
函数生成
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
然后,此密钥可以用于任何 JAX 的随机数生成例程中
>>> random.uniform(key)
Array(0.947667, dtype=float32)
请注意,使用密钥不会修改它,因此重用相同的密钥将导致相同的结果
>>> random.uniform(key)
Array(0.947667, dtype=float32)
如果您需要新的随机数,可以使用 jax.random.split()
生成新的子密钥
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.00729382, dtype=float32)
注意
类型化密钥数组,具有诸如上面的 key<fry>
之类的元素类型,是在 JAX v0.4.16 中引入的。在此之前,密钥通常以 uint32
数组表示,其最后一个维度表示密钥的位级表示。
两种形式的密钥数组仍然可以使用 jax.random
模块创建和使用。新的类型化密钥数组使用 jax.random.key()
创建。旧式的 uint32
密钥数组使用 jax.random.PRNGKey()
创建。
要在两者之间转换,请使用 jax.random.key_data()
和 jax.random.wrap_key_data()
。当与 JAX 之外的系统(例如,将数组导出为可序列化格式)接口,或将密钥传递给假定旧格式的基于 JAX 的库时,可能需要旧式密钥格式。
否则,建议使用类型化密钥。相对于类型化密钥,旧式密钥的注意事项包括
它们有一个额外的尾随维度。
它们具有数值数据类型 (
uint32
),允许进行通常不应在密钥上执行的操作,例如整数算术。它们不携带有关 RNG 实现的信息。当旧式密钥传递给
jax.random
函数时,全局配置设置确定 RNG 实现(请参阅下面的“高级 RNG 配置”)。
要了解有关此升级和密钥类型设计的更多信息,请参阅 JEP 9263。
高级#
设计和背景#
TLDR: JAX PRNG = Threefry 计数器 PRNG + 一个函数式、面向数组的 拆分模型
有关更多详细信息,请参阅 docs/jep/263-prng.md。
总而言之,除其他要求外,JAX PRNG 旨在
确保可重复性,
在向量化(生成数组值)和多副本、多核计算方面都具有良好的并行化能力。特别是,它不应在随机函数调用之间使用排序约束。
高级 RNG 配置#
JAX 提供了几种 PRNG 实现。可以使用 jax.random.key
的可选 impl
关键字参数选择特定的实现。当没有 impl
选项传递给 key
构造函数时,实现由全局 jax_default_prng_impl
配置标志确定。可用实现的字符串名称为
"threefry2x32"
(默认): 基于 Threefry 哈希函数变体的基于计数器的 PRNG,如 Salmon 等人,2011 年的这篇论文中所述。"rbg"
和"unsafe_rbg"
(实验性): 构建于 XLA 的随机位生成器 (RBG) 算法之上的 PRNG。"rbg"
使用 XLA RBG 进行随机数生成,而对于密钥派生(如jax.random.split
和jax.random.fold_in
中),它使用与"threefry2x32"
相同的方法。"unsafe_rbg"
将 XLA RBG 用于生成和密钥派生。
这些实验性方案生成的随机数尚未经过经验随机性测试(例如 BigCrush)。
"unsafe_rbg"
中的密钥派生也未经经验测试。名称强调“unsafe”,因为密钥派生质量和生成质量尚不清楚。此外,
"rbg"
和"unsafe_rbg"
在jax.vmap
下的行为异常。当 vmapping 一个随机函数在一批密钥上时,它的输出值可能与其在相同密钥上的真实映射不同。相反,在vmap
下,整批输出随机数仅从输入密钥批次中的第一个密钥生成。例如,如果keys
是 8 个密钥的向量,则jax.vmap(jax.random.normal)(keys)
等于jax.random.normal(keys[0], shape=(8,))
。这种特殊性反映了 XLA RBG 有限批处理支持的解决方法。
使用默认 RNG 替代方案的原因包括:
对于 TPU 编译可能很慢。
在 TPU 上执行相对较慢。
自动分区
为了使 jax.jit
能够有效地自动分区生成分片随机数数组(或密钥数组)的函数,所有 PRNG 实现都需要额外的标志
对于
"threefry2x32"
和"rbg"
密钥派生,设置jax_threefry_partitionable=True
。对于
"unsafe_rbg"
和"rbg"
随机生成”,设置 XLA 标志--xla_tpu_spmd_rng_bit_generator_unsafe=1
。
可以使用 XLA_FLAGS
环境变量设置 XLA 标志,例如 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
。
有关 jax_threefry_partitionable
的更多信息,请参阅 https://jax.net.cn/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
总结
属性 |
Threefry |
Threefry* |
rbg |
unsafe_rbg |
rbg** |
unsafe_rbg** |
---|---|---|---|---|---|---|
TPU 上最快 |
✅ |
✅ |
✅ |
✅ |
||
高效可分片 (w/ pjit) |
✅ |
✅ |
✅ |
|||
跨分片相同 |
✅ |
✅ |
✅ |
✅ |
||
跨 CPU/GPU/TPU 相同 |
✅ |
✅ |
||||
密钥上的精确 |
✅ |
✅ |
(*): 设置 jax_threefry_partitionable=1
(**): 设置 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
API 参考#
密钥创建与操作#
|
创建一个给定整数种子的伪随机数生成器 (PRNG) 密钥。 |
|
恢复 PRNG 密钥数组的基础密钥数据位。 |
|
将密钥数据位数组包装到 PRNG 密钥数组中。 |
|
将数据折叠到 PRNG 密钥中以形成新的 PRNG 密钥。 |
|
通过添加前导轴将一个 PRNG 密钥拆分为 num 个新密钥。 |
|
克隆密钥以供重用 |
|
创建一个给定整数种子的旧式 PRNG 密钥。 |
随机采样器#
|
从单位 Lp 球中均匀采样。 |
|
采样具有给定形状和均值的伯努利随机值。 |
|
采样具有给定形状和浮点数据类型的 Beta 随机值。 |
|
采样具有给定形状和浮点数据类型的二项式随机值。 |
|
以无符号整数的形式采样均匀位。 |
|
从分类分布中采样随机值。 |
|
采样具有给定形状和浮点数据类型的柯西随机值。 |
|
采样具有给定形状和浮点数据类型的卡方随机值。 |
|
从给定数组生成随机样本。 |
|
采样具有给定形状和浮点数据类型的狄利克雷随机值。 |
|
从双边麦克斯韦分布中采样。 |
|
采样具有给定形状和浮点数据类型的指数随机值。 |
|
采样具有给定形状和浮点数据类型的 F 分布随机值。 |
|
采样具有给定形状和浮点数据类型的 Gamma 随机值。 |
|
从广义正态分布中采样。 |
|
采样具有给定形状和浮点数据类型的几何随机值。 |
|
采样具有给定形状和浮点数据类型的耿贝尔随机值。 |
|
采样具有给定形状和浮点数据类型的拉普拉斯随机值。 |
|
采样具有给定形状和浮点数据类型的对数伽玛随机值。 |
|
采样具有给定形状和浮点数据类型的逻辑随机值。 |
|
采样具有给定形状和浮点数据类型的对数正态随机值。 |
|
从单边麦克斯韦分布中采样。 |
|
从多项分布中采样。 |
|
采样具有给定均值和协方差的多元正态随机值。 |
|
采样具有给定形状和浮点数据类型的标准正态随机值。 |
|
从正交群 O(n) 中均匀采样。 |
|
采样具有给定形状和浮点数据类型的帕累托随机值。 |
|
返回一个随机排列的数组或范围。 |
|
采样具有给定形状和整数数据类型的泊松随机值。 |
|
从拉德马赫分布中采样。 |
|
采样 [minval, maxval) 中具有给定形状/数据类型的均匀随机值。 |
|
采样具有给定形状和浮点数据类型的瑞利随机值。 |
|
采样具有给定形状和浮点数据类型的学生 t 随机值。 |
|
采样具有给定形状和浮点数据类型的三角随机值。 |
|
采样具有给定形状和数据类型的截断标准正态随机值。 |
|
采样 [minval, maxval) 中具有给定形状/数据类型的均匀随机值。 |
|
采样具有给定形状和浮点数据类型的瓦尔德随机值。 |
|
从 Weibull 分布中采样。 |