JEP 28845:JAX 中的有状态随机数#
@jakevdp, 2025 年 11 月
本文档探讨了在 JAX 中添加一个可选的有状态伪随机数生成器 (PRNG);这旨在与 伪随机数 中描述的经典函数式 PRNG 一起使用,以便在有状态处理更方便的情况下使用。
背景#
JAX 一直要求用户作为其函数式编程范式的一部分显式管理随机状态(有关背景信息,请参阅 JAX PRNG 设计)。虽然此设计初衷良好,但对于习惯了有状态伪随机数 API 的新用户来说,这是一个经常遇到的 陷阱。
随着 JAX 最近引入了有限范围的 可变引用 (mutable refs),现在可以在 JAX 中实现一个有状态的 PRNG,它在保留现有函数式 PRNG 大部分性能优势的同时,为熟悉 NumPy、PyTorch 和其他数值计算库的用户提供了更自然的 API。
本 JAX 增强提案(或 JEP)提议在 jax.experimental.random 中引入有状态 PRNG API,目标是最终将其纳入 jax.random 本身。
API 设计#
为了与更广泛的 Python 数值计算社区中形成的最佳实践保持一致,我们提议使有状态 PRNG API 与 NumPy 最新的 PRNG API 迭代保持一致(见 numpy.random.Generator),该 API 通常使用 numpy.random.default_rng() 函数创建。拟议实现的完整草案可以在 #28845 中找到,但此处我们总结了该实现的主要功能。
有状态 PRNG 生成器代码的简化版本如下所示(函数和参数名称遵循 numpy.random API):
def stateful_rng(seed: ArrayLike) -> StatefulPRNG:
"""Create a stateful PRNG Generator given an integer seed."""
return StatefulPRNG(jax.random.key(seed), jax.new_ref(0))
@tree_util.register_dataclass
@dataclass(frozen=True)
class StatefulPRNG:
"""Stateful PRNG Generator class."""
base_key: jax.Array
counter: jax.core.Ref
def key(self) -> jax.Array:
"""Generate a new jax PRNG key"""
key = jax.random.fold_in(self.base_key, self.counter[...])
jax.ref.addupdate(self.counter, ..., 1) # increment counter
return key
def random(self, size: Sequence[int], dtype: DType = float):
"""Return random floats in the half-open interval [0, 1)"""
return random.uniform(self.key(), shape=size, dtype=dtype)
# uniform(), normal(), integers(), and others implemented similarly.
随着此实现暴露在 jax.experimental.random 命名空间中,其用法与 numpy.random.default_rng() 几乎相同。
>>> from jax.experimental.random import stateful_rng
>>> rng = stateful_rng(1701)
>>> rng.random((5,))
Array([0.09609699, 0.26730824, 0.5619041 , 0.24421775, 0.7715055 ], dtype=float32)
>>> rng.random((5,)) # state is updated -> new random draws!
Array([0.8131045 , 0.33873856, 0.88808906, 0.96005905, 0.7616446 ], dtype=float32)
>>> import numpy as np
>>> rng = np.random.default_rng(1701)
>>> rng.random((5,))
array([0.4020733 , 0.30563311, 0.67668051, 0.15821208, 0.79247763])
>>> rng.random((5,))
array([0.09419469, 0.36753944, 0.06388928, 0.96431608, 0.35200998])
由于 jax.experimental.random.StatefulPRNG 中的有状态性是通过可变引用来跟踪的,因此即使生成器被用于像 jax.jit() 这样通常需要纯函数式语义的变换中,随机状态也会正确更新。
与 vmap 和 shard_map 的交互#
拟议的有状态 RNG 设计基于引用,因此在 vmap 和 shard_map 下,它继承了引用的限制。例如,您不能直接在 vmap 变换后的函数中使用未映射的 rng。
rng = stateful_rng(0)
def f(x):
return x + rng.uniform()
jax.vmap(f)(jnp.arange(10))
Exception: performing an addupdate operation with vmapped value on an unbatched
array reference of type Ref{int32[]}. Move the array reference to be
an argument to the vmapped function?
因此,我们需要能够拆分生成器,以便将其传递给映射或分片代码。为此,我们向 StatefulPRNG 类添加了一个 split 方法,如下所示:
class StatefulPRNG:
...
def split(self, num: int | Sequence[int]) -> StatefulPRNG:
return StatefulPRNG(
base_key=jax.random.split(self.key(), num),
counter=jnp.zeros(num, dtype=int),
)
有了这个方法,有状态 rng 可以被显式拆分并传递给 vmap 变换后的函数。
rng = jax.experimental.random.stateful_rng(0)
def f(x, rng):
return x + rng.uniform()
result = jax.vmap(f)(jnp.arange(5), rng.split(5))
print(result) # [0.07174575 1.0163325 2.0435536 3.4391735 4.534091 ]
类似的方法也适用于分片计算,尽管 split 可能需要增加一个 sharding 参数。
这种拆分提出了一个问题:如果用户尝试直接从拆分后的生成器(如 rng.split(10).uniform())生成随机数,该怎么办?为此,我们遵循经典无状态 jax.random API 在接收批量键时的先例,并引发一个提示性错误。
统计学考量#
在拟议的设计中,随机状态通过一个基础键和一个每次生成键时都会递增的整数计数器来跟踪。我们选择这种方法而不是直接改变键本身,是为了避免迭代拆分(参见 INSERT_REF_HERE)的缺陷;特别是这意味着有状态生成器在循环回到零并重复初始键之前,将始终完全遍历 32 位或 64 位键空间。
优势#
这种方法的主要优势是熟悉度:许多用户熟悉 NumPy 及其有状态 RNG 实用程序。这将使他们能够更直接地开始使用 JAX,而无需克服不熟悉的函数式 PRNG API 的学习曲线。
这不仅影响 JAX 用户:为了方便,即使是 JAX 开发人员也倾向于在变换之外切换上下文并使用有状态的 NumPy API,因为在那些地方函数式 PRNG 是不必要的。这导致了 JAX 用户的困惑(例如,参见 这个 GitHub 讨论)。拥有一个 JAX 原生的有状态 API 将使在实时演示和书面教程中始终使用 JAX PRNG 变得更加方便。
函数式 PRNG 的另一个陷阱是意外重复使用键的可能性。不熟悉显式状态需求的用户可能会多次使用相同的键,从而无意中生成统计上相关的随机值(例如,参见 这个 StackOverflow 问题)。通过鼓励新 JAX 用户使用有状态 PRNG,我们可以避免这种无声的陷阱。
最后,该 API 提供了调用 rng.key() 以创建标准函数式 PRNG 键的能力,该键随后可以在典型的函数式模式中使用:这是在需要显式管理状态的情况下,向其过渡的简单入口。
局限性#
通过可变引用实现有状态 PRNG 键会带来一些固有的局限性,特别是:
序列依赖性限制了编译器: 使用此类键的程序在程序内部强加了固有的序列依赖性,这意味着编译器将无法自由地重新排序依赖于伪随机值的操作。这种情况下的陷阱是静默的:用户需要自行识别哪里会出现问题,并转而对预先生成的键或值序列使用批量执行模式。但请注意,当用户遵循 JAX 文档中的当前使用建议时,这种序列依赖性陷阱也存在:https://jax.net.cn/en/latest/jax.random.html#basic-usage。
序列依赖性限制了用户: 同样,正如编译器无法在不改变随机性的情况下重新排序操作一样,这种序列依赖性也意味着用户无法在不改变特定随机抽取结果的情况下轻松重构代码。一个潜在的例子:假设在神经网络内部使用了一个有状态 RNG,而用户决定用一个具有不同随机抽取结果的层来替换内部层:这将消耗一个键并影响随后所有层的随机抽取结果。
与 remat 不兼容: 由于可变引用依赖于 JAX 的效果系统,这些 API 在不支持效果的地方将无法使用。特别是,这意味着在 JAX 的当前实现中,有状态键将与 remat 不兼容,这可能会限制它们在神经网络实现中的有用性。这种情况下的陷阱是明显的:尝试在 remat 内部使用可变引用将导致显式错误。未来对 remat 的重新设计有可能消除这种不兼容性(关于此进展,请参阅 #33018)。
引用不能作为返回值: 可变引用不能存在于变换后的 JAX 函数的返回值中,而拟议的有状态 RNG 对象将继承此限制。这也是一个显式的限制:尝试从变换后的函数中返回 StatefulPRNG 将导致显式错误。
评估#
我们的判断是,有状态 PRNG API 的优势可能超过其局限性,我们目前应该在 jax.experimental.random 模块中引入一个新的实验性 stateful_rng() API。一旦我们对其有用性有了体会,我们最终可能会将此 API 升级到 jax.random 模块中,或许会在 jax.numpy.random 中提供一个 default_rng 别名。