jax.random 模块#
伪随机数生成的实用工具。
包提供了一系列用于确定性生成伪随机数序列的例程。jax.random
基本用法#
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches))
PRNG 密钥#
与 NumPy 和 SciPy 用户可能习惯的有状态伪随机数生成器 (PRNG) 不同,JAX 的随机函数都需要将一个显式的 PRNG 状态作为第一个参数传递。随机状态由一种特殊的数组元素类型描述,我们称之为**密钥**,通常由 函数生成。jax.random.key()
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
这个密钥随后可以用于 JAX 的任何随机数生成例程。
>>> random.uniform(key)
Array(0.947667, dtype=float32)
请注意,使用密钥不会修改它,因此重复使用相同的密钥将导致相同的结果。
>>> random.uniform(key)
Array(0.947667, dtype=float32)
如果您需要新的随机数,可以使用 生成新的子密钥。jax.random.split()
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.00729382, dtype=float32)
注意
类型化的密钥数组,如上面的 key<fry> 等元素类型,是在 JAX v0.4.16 中引入的。在此之前,密钥通常用 uint32 数组表示,其最后一个维度表示密钥的位级别表示。
两种形式的密钥数组仍然可以使用 模块创建和使用。新式类型化密钥数组使用 jax.random 创建。旧式 jax.random.key()uint32 密钥数组使用 创建。jax.random.PRNGKey()
要在这两者之间进行转换,请使用 和 jax.random.key_data()。当与 JAX 外部系统交互时(例如,将数组导出为可序列化格式),或在将密钥传递给假定旧式格式的基于 JAX 的库时,可能需要旧式密钥格式。jax.random.wrap_key_data()
否则,推荐使用类型化密钥。与类型化密钥相比,旧式密钥的缺点包括:
它们有一个额外的尾随维度。
它们具有数值 dtype(
uint32),允许执行通常不应在密钥上执行的操作,例如整数算术。它们不携带有关 RNG 实现的信息。当旧式密钥传递给
函数时,全局配置设置将决定 RNG 实现(参见下面的“高级 RNG 配置”)。jax.random
要了解有关此次升级以及密钥类型设计的更多信息,请参阅 JEP 9263。
高级#
设计与背景#
简而言之:JAX PRNG = Threefry counter PRNG + 一个函数式、面向数组的 分裂模型。
有关更多详细信息,请参阅 docs/jep/263-prng.md。
总结来说,JAX PRNG 旨在实现以下目标(包括其他要求):
确保可重现性;
良好的并行性,包括向量化(生成数组值)和多副本、多核计算。特别是,它不应该在随机函数调用之间使用排序约束。
高级 RNG 配置#
JAX 提供了多种 PRNG 实现。可以通过 的可选 jax.random.keyimpl 关键字参数选择特定的实现。当 构造函数未传递 keyimpl 选项时,实现由全局 配置标志决定。可用实现的字符串名称包括:jax_default_prng_impl
"threefry2x32"(**默认**):一个基于 Threefry 哈希函数变体的计数器 PRNG,如 Salmon 等人 2011 年的这篇论文 所述。"rbg"和"unsafe_rbg"(**实验性**):基于 XLA 的随机位生成器 (RBG) 算法 构建的 PRNG。"rbg"使用 XLA RBG 进行随机数生成,而对于密钥派生(如和jax.random.split),它使用与jax.random.fold_in相同的方法。"threefry2x32""unsafe_rbg"同时用于生成和密钥派生。:
通过这些实验性方案生成的随机数尚未经过经验随机性测试(例如 BigCrush)。
中的密钥派生也尚未经过经验测试。该名称强调“不安全”是因为密钥派生质量和生成质量尚不明确。"unsafe_rbg"此外,
和"rbg""unsafe_rbg"在下表现异常。当沿着密钥批次对随机函数进行 vmap 时,其输出值可能与对相同密钥进行真实 map 的输出值不同。相反,在jax.vmap下,整个输出随机数批次仅从输入密钥批次中的第一个密钥生成。例如,如果vmapkeys是一个包含 8 个密钥的向量,则等于jax.vmap(jax.random.normal)(keys)。这种特殊性反映了对 XLA RBG 有限的批处理支持的变通方法。jax.random.normal(keys[0], shape=(8,))
使用默认 RNG 以外的其他 RNG 的原因包括:
在 TPU 上编译可能很慢。
在 TPU 上执行相对较慢。
自动分区
为了让 能够有效地自动分区生成分片随机数数组(或密钥数组)的函数,所有 PRNG 实现都依赖于额外的标志。jax.jit
对于
和"threefry2x32"密钥派生,设置"rbg"jax_threefry_partitionable=True。截至 JAX v.0.5.0,这是默认设置。对于
和"unsafe_rbg"随机生成,请设置 XLA 标志"rbg"。--xla_tpu_spmd_rng_bit_generator_unsafe=1
可以使用 环境变量设置 XLA 标志,例如 XLA_FLAGS。XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
有关 的更多信息,请参阅 jax-ml/jax#18480。jax_threefry_partitionable
概述
属性 |
Threefry |
Threefry* |
rbg |
unsafe_rbg |
rbg** |
unsafe_rbg** |
|---|---|---|---|---|---|---|
TPU 上最快 |
✅ |
✅ |
✅ |
✅ |
||
高效分片 (w/ pjit) |
✅ |
✅ |
✅ |
|||
分片之间相同 |
✅ |
✅ |
✅ |
✅ |
||
CPU/GPU/TPU 之间相同 |
✅ |
✅ |
||||
精确的 |
✅ |
✅ |
(*): 设置了 (JAX v0.5.0 起为默认设置)jax_threefry_partitionable=1
(**): 设置了 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
API 参考#
密钥创建与操作#
|
根据整数种子创建伪随机数生成器 (PRNG) 密钥。 |
|
恢复 PRNG 密钥数组底层的密钥数据位。 |
|
将密钥数据位数组包装成 PRNG 密钥数组。 |
|
将数据折叠到 PRNG 密钥中以形成新的 PRNG 密钥。 |
|
将 PRNG 密钥分割成 num 个新密钥,方法是添加一个前导轴。 |
|
克隆密钥以供重用。 |
|
根据整数种子创建旧式 PRNG 密钥。 |
随机采样器#
|
从单位 Lp 球体中均匀采样。 |
|
以给定形状和均值采样伯努利随机值。 |
|
以给定形状和浮点 dtype 采样 Beta 随机值。 |
|
以给定形状和浮点 dtype 采样二项式随机值。 |
|
以无符号整数的形式采样均匀比特。 |
|
从分类分布中采样随机值。 |
|
以给定形状和浮点 dtype 采样柯西随机值。 |
|
以给定形状和浮点 dtype 采样卡方随机值。 |
|
从给定数组中生成随机样本。 |
|
以给定形状和浮点 dtype 采样 Dirichlet 随机值。 |
|
从双边麦克斯韦分布中采样。 |
|
以给定形状和浮点 dtype 采样指数随机值。 |
|
以给定形状和浮点 dtype 采样 F 分布随机值。 |
|
以给定形状和浮点 dtype 采样 Gamma 随机值。 |
|
从广义正态分布中采样。 |
|
以给定形状和浮点 dtype 采样几何随机值。 |
|
以给定形状和浮点 dtype 采样 Gumbel 随机值。 |
|
以给定形状和浮点 dtype 采样拉普拉斯随机值。 |
|
以给定形状和浮点 dtype 采样对数伽马随机值。 |
|
以给定形状和浮点 dtype 采样 logistic 随机值。 |
|
以给定形状和浮点 dtype 采样对数正态随机值。 |
|
从单边麦克斯韦分布中采样。 |
|
从多项分布中采样。 |
|
以给定的均值和协方差采样多元正态随机值。 |
|
以给定的形状和浮点 dtype 采样标准正态随机值。 |
|
从正交群 O(n) 中均匀采样。 |
|
以给定形状和浮点 dtype 采样帕累托随机值。 |
|
返回随机排列的数组或范围。 |
|
以给定的形状和整数 dtype 采样泊松随机值。 |
|
从 Rademacher 分布中采样。 |
|
以给定形状/数据类型采样 [minval, maxval) 范围内的均匀随机值。 |
|
以给定的形状和浮点 dtype 采样瑞利随机值。 |
|
以给定的形状和浮点 dtype 采样学生 t 随机值。 |
|
以给定形状和浮点 dtype 采样三角形随机值。 |
|
以给定的形状和 dtype 采样截断的标准正态随机值。 |
|
以给定形状/数据类型采样 [minval, maxval) 范围内的均匀随机值。 |
|
以给定的形状和浮点 dtype 采样 Wald 随机值。 |
|
从 Weibull 分布中采样。 |