jax.random
模块#
用于伪随机数生成的实用程序。
jax.random
包提供了一些例程,用于确定性地生成伪随机数序列。
基本用法#
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches))
PRNG 键#
与 NumPy 和 SciPy 用户可能习惯的*有状态*伪随机数生成器(PRNG)不同,JAX 的所有随机函数都要求将一个显式的 PRNG 状态作为第一个参数传递。该随机状态由一种特殊的数组元素类型描述,我们称之为**键**,通常由 jax.random.key()
函数生成。
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
这个键然后可以用于 JAX 的任何随机数生成例程中。
>>> random.uniform(key)
Array(0.947667, dtype=float32)
请注意,使用键不会修改它,因此重复使用同一个键将导致相同的结果。
>>> random.uniform(key)
Array(0.947667, dtype=float32)
如果您需要一个新的随机数,您可以使用 jax.random.split()
来生成新的子键。
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.00729382, dtype=float32)
注意
类型化的键数组,其元素类型如上所示的 key<fry>
,是在 JAX v0.4.16 中引入的。在此之前,键通常表示为 uint32
数组,其最后一个维度表示键的位级表示。
这两种形式的键数组仍然可以使用 jax.random
模块创建和使用。新式类型化键数组通过 jax.random.key()
创建。旧版 uint32
键数组通过 jax.random.PRNGKey()
创建。
要在两者之间进行转换,请使用 jax.random.key_data()
和 jax.random.wrap_key_data()
。在与 JAX 外部系统(例如将数组导出为可序列化格式)交互时,或者在将键传递给假定使用旧版格式的基于 JAX 的库时,可能需要旧版键格式。
否则,推荐使用类型化的键。与类型化的键相比,旧版键的注意事项包括:
它们有一个额外的尾随维度。
它们具有数值数据类型(
uint32
),这允许执行通常不适用于键的操作,例如整数算术。它们不包含关于 RNG 实现的信息。当旧版键传递给
jax.random
函数时,一个全局配置设置会决定 RNG 实现(参见下面的“高级 RNG 配置”)。
要了解更多关于此升级和键类型设计的信息,请参见 JEP 9263。
高级#
设计与背景#
TLDR:JAX PRNG = Threefry 计数器 PRNG + 函数式数组导向的 拆分模型
参见 docs/jep/263-prng.md 获取更多详情。
总而言之,除了其他要求外,JAX PRNG 旨在:
确保可复现性,
良好并行化,无论是在向量化(生成数组值)还是多副本、多核计算方面。特别是,它不应在随机函数调用之间使用序列化约束。
高级 RNG 配置#
JAX 提供了几种 PRNG 实现。可以通过 jax.random.key
的可选 impl
关键字参数来选择一个特定的实现。当没有向 key
构造函数传递 impl
选项时,实现由全局 jax_default_prng_impl
配置标志决定。可用实现的字符串名称包括:
"threefry2x32"
(默认):一个基于计数器的 PRNG,基于 Threefry 哈希函数的一种变体,如 Salmon 等人于 2011 年的这篇论文中所述。"rbg"
和"unsafe_rbg"
(实验性):基于 XLA 的随机位生成器(RBG)算法构建的 PRNG。"rbg"
使用 XLA RBG 进行随机数生成,而对于键派生(如在jax.random.split
和jax.random.fold_in
中),它使用与"threefry2x32"
相同的方法。"unsafe_rbg"
在生成和键派生中都使用 XLA RBG。
由这些实验性方案生成的随机数尚未经过经验随机性测试(例如 BigCrush)。
"unsafe_rbg"
中的键派生也未经经验测试。该名称强调“不安全”是因为键派生质量和生成质量尚未充分理解。此外,
"rbg"
和"unsafe_rbg"
在jax.vmap
下表现异常。当对一批键执行随机函数的 vmap 操作时,其输出值可能与对相同键的真实映射不同。相反,在vmap
下,整个批次的输出随机数仅从输入键批次中的第一个键生成。例如,如果keys
是一个包含 8 个键的向量,则jax.vmap(jax.random.normal)(keys)
等于jax.random.normal(keys[0], shape=(8,))
。这种特殊性反映了对 XLA RBG 有限批处理支持的变通方案。
使用替代默认 RNG 的原因包括:
它可能对 TPU 编译缓慢。
它在 TPU 上执行相对较慢。
自动分区
为了使 jax.jit
能够有效地自动分区生成分片随机数数组(或键数组)的函数,所有 PRNG 实现都需要额外的标志:
对于
"threefry2x32"
和"rbg"
键派生,请设置jax_threefry_partitionable=True
。对于
"unsafe_rbg"
和"rbg"
随机生成,请设置 XLA 标志--xla_tpu_spmd_rng_bit_generator_unsafe=1
。
XLA 标志可以使用 XLA_FLAGS
环境变量设置,例如 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
。
更多关于 jax_threefry_partitionable
的信息,请参见 https://jax.net.cn/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
概述
属性 |
Threefry |
Threefry* |
rbg |
unsafe_rbg |
rbg** |
unsafe_rbg** |
---|---|---|---|---|---|---|
在 TPU 上最快 |
✅ |
✅ |
✅ |
✅ |
||
高效可分片(带 pjit) |
✅ |
✅ |
✅ |
|||
跨分片一致 |
✅ |
✅ |
✅ |
✅ |
||
跨 CPU/GPU/TPU 一致 |
✅ |
✅ |
||||
对键的精确 |
✅ |
✅ |
(*): 设置了 jax_threefry_partitionable=1
(**): 设置了 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
API 参考#
键的创建和操作#
|
根据整数种子创建一个伪随机数生成器(PRNG)键。 |
|
恢复 PRNG 键数组底层键数据的位。 |
|
将键数据位的数组封装成一个 PRNG 键数组。 |
|
将数据折叠到 PRNG 键中以形成新的 PRNG 键。 |
|
通过添加一个前导轴,将一个 PRNG 键拆分成 num 个新键。 |
|
克隆一个键以供重用 |
|
根据整数种子创建一个旧版 PRNG 键。 |
随机采样器#
|
从单位 Lp 球中均匀采样。 |
|
以给定形状和均值采样伯努利随机值。 |
|
以给定形状和浮点数据类型采样 Beta 随机值。 |
|
以给定形状和浮点数据类型采样二项式随机值。 |
|
以无符号整数形式采样均匀位。 |
|
从分类分布中采样随机值。 |
|
以给定形状和浮点数据类型采样柯西随机值。 |
|
以给定形状和浮点数据类型采样卡方随机值。 |
|
从给定数组生成随机样本。 |
|
以给定形状和浮点数据类型采样狄利克雷随机值。 |
|
从双侧麦克斯韦分布中采样。 |
|
以给定形状和浮点数据类型采样指数随机值。 |
|
以给定形状和浮点数据类型采样 F-分布随机值。 |
|
以给定形状和浮点数据类型采样 Gamma 随机值。 |
|
从广义正态分布中采样。 |
|
以给定形状和浮点数据类型采样几何随机值。 |
|
以给定形状和浮点数据类型采样 Gumbel 随机值。 |
|
以给定形状和浮点数据类型采样拉普拉斯随机值。 |
|
以给定形状和浮点数据类型采样对数伽马随机值。 |
|
以给定形状和浮点数据类型采样逻辑斯谛随机值。 |
|
以给定形状和浮点数据类型采样对数正态随机值。 |
|
从单侧麦克斯韦分布中采样。 |
|
从多项式分布中采样。 |
|
以给定均值和协方差采样多元正态随机值。 |
|
以给定形状和浮点数据类型采样标准正态随机值。 |
|
从正交群 O(n) 中均匀采样。 |
|
以给定形状和浮点数据类型采样帕累托随机值。 |
|
返回一个随机排列的数组或范围。 |
|
以给定形状和整数数据类型采样泊松随机值。 |
|
从拉德马赫分布中采样。 |
|
以给定形状/数据类型采样 [minval, maxval) 范围内的均匀随机值。 |
|
以给定形状和浮点数据类型采样瑞利随机值。 |
|
以给定形状和浮点数据类型采样学生 t 随机值。 |
|
以给定形状和浮点数据类型采样三角随机值。 |
|
以给定形状和数据类型采样截断标准正态随机值。 |
|
以给定形状/数据类型采样 [minval, maxval) 范围内的均匀随机值。 |
|
以给定形状和浮点数据类型采样 Wald 随机值。 |
|
从威布尔分布中采样。 |