JEP 9263:类型化密钥 & 可插拔 RNG#
Jake VanderPlas, Roy Frostig
2023 年 8 月
概述#
今后,JAX 中的 RNG 密钥将更加类型安全和可自定义。它将不再表示为长度为 2 的 uint32 数组,而是表示为一个具有特殊 RNG 数据类型的标量数组,该数组满足 jnp.issubdtype(key.dtype, jax.dtypes.prng_key)。
目前,仍然可以使用 jax.random.PRNGKey() 创建旧式 RNG 密钥。
>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')
从现在开始,可以使用 jax.random.key() 创建新式 RNG 密钥。
>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>
这个(标量形状的)数组的行为与其他 JAX 数组相同,只是其元素类型是密钥(及相关元数据)。我们也可以创建非标量密钥数组,例如通过将 jax.vmap() 应用到 jax.random.key()。
>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
[0 1]
[0 2]
[0 3]]
>>> key_arr.shape
(4,)
除了切换到新的构造函数外,大多数与 PRNG 相关的代码应该仍然按预期工作。您可以继续在 jax.random API 中使用密钥,例如:
# split
new_key, subkey = jax.random.split(key)
# random number generation
data = jax.random.uniform(key, shape=(5,))
但是,并非所有数值运算都适用于密钥数组。它们现在会故意引发错误。
>>> key = key + 1
Traceback (most recent call last):
TypeError: add does not accept dtypes key<fry>, int32.
如果您出于某种原因需要恢复底层缓冲区(旧式密钥),可以使用 jax.random.key_data() 来实现。
>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)
对于旧式密钥,key_data() 是一个恒等操作。
这对用户意味着什么?#
对于 JAX 用户来说,此次更改目前不需要任何代码更改,但我们希望您会发现升级是值得的,并切换到使用类型化密钥。要尝试一下,请将 `jax.random.PRNGKey()` 的用法替换为 `jax.random.key()`。这可能会在您的代码中引入一些中断,这些中断可能属于以下几类:
如果您的代码对密钥执行不安全/不支持的操作(例如索引、算术、转置等;请参阅下面的类型安全性部分),此更改将捕获它。您可以更新代码以避免此类不支持的操作,或者使用
jax.random.key_data()和jax.random.wrap_key_data()以不安全的方式操作原始密钥缓冲区。如果您的代码包含关于
key.shape的显式逻辑,您可能需要更新此逻辑以考虑到尾随密钥缓冲区维度不再是形状的显式一部分的事实。如果您的代码包含关于
key.dtype的显式逻辑,您将需要升级它以使用新的公共 API 来推理 RNG 数据类型,例如dtypes.issubdtype(dtype, dtypes.prng_key)。如果您调用一个尚不支持类型化 PRNG 密钥的 JAX 库,您现在可以使用 `raw_key = jax.random.key_data(key)` 来恢复原始缓冲区,但请保留一个 TODO,以便在下游库支持类型化 RNG 密钥后将其删除。
将来某个时候,我们计划弃用 jax.random.PRNGKey() 并强制使用 jax.random.key()。
检测新式类型化密钥#
要检查一个对象是否为新式类型化 PRNG 密钥,您可以使用 jax.dtypes.issubdtype 或 jax.numpy.issubdtype。
>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False
PRNG 密钥的类型注解#
推荐的旧式和新式 PRNG 密钥的类型注解都是 jax.Array。PRNG 密钥通过其 dtype 与其他数组区分开,目前无法在类型注解中指定 JAX 数组的数据类型。以前可以使用 jax.random.KeyArray 或 jax.random.PRNGKeyArray 作为类型注解,但它们在类型检查时始终别名为 Any,因此 jax.Array 具有更高的特异性。
注意:jax.random.KeyArray 和 jax.random.PRNGKeyArray 在 JAX 版本 0.4.16 中已弃用,并在 JAX 版本 0.4.24 中删除。.
动机#
这次更改的两个主要驱动因素是可定制性和安全性。
自定义 PRNG 实现#
JAX 目前使用单一的、全局配置的 PRNG 算法。PRNG 密钥是无符号 32 位整数的向量,jax.random API 会使用它们来生成伪随机流。任何更高秩的 uint32 数组都被解释为此类密钥缓冲区的数组,其中最后一个维度代表密钥。
随着我们引入替代的 PRNG 实现,这种设计的缺点变得更加明显,这些实现必须通过设置全局或局部配置标志来选择。不同的 PRNG 实现具有不同的密钥缓冲区大小,以及不同的生成随机位的算法。使用全局标志确定此行为很容易出错,尤其是在进程范围内使用多个密钥实现时。
我们的新方法是将实现作为 PRNG 密钥类型的一部分进行携带,即作为密钥数组的元素类型。使用新的密钥 API,以下是一个在默认的三重二乘 32 位实现(用纯 Python 实现并用 JAX 编译)和非默认的 rbg 实现(对应于单个 XLA 位生成操作)下生成伪随机值的示例。
>>> key = jax.random.key(0, impl='threefry2x32') # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.947667 , 0.9785799 , 0.33229148], dtype=float32)
>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
安全使用 PRNG 密钥#
原则上,PRNG 密钥只应支持少数操作,即密钥派生(例如拆分)和随机数生成。PRNG 的设计目的是生成独立的伪随机数,前提是密钥已正确拆分并且每个密钥只使用一次。
以其他方式操作或消耗密钥数据的代码通常表明存在意外的错误,而将密钥数组表示为原始 uint32 缓冲区很容易导致这方面的滥用。以下是一些我们在实际使用中遇到的示例滥用:
密钥缓冲区索引#
访问底层整数缓冲区很容易尝试以非标准方式派生密钥,有时会产生出乎意料的糟糕后果。
# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1]) # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)
如果这个密钥是用 `random.key(999)` 创建的新式类型化密钥,那么索引到密钥缓冲区会报错。
密钥算术运算#
密钥算术运算是从其他密钥派生密钥的一种同样危险的方法。通过直接操作密钥数据来派生密钥,而不是使用 jax.random.split() 或 jax.random.fold_in(),会产生一个密钥批次,该批次——取决于 PRNG 实现——随后可能会在批次内生成相关的随机数。
# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)
使用 `random.key(0)` 创建的新式类型化密钥通过禁止对密钥进行算术运算来解决此问题。
无意中转置密钥缓冲区#
使用“原始”旧式密钥数组,很容易意外地交换批次(前导)维度和密钥缓冲区(尾随)维度。同样,这可能导致生成的伪随机性相关的密钥。我们随时间观察到的一个模式是这样的:
# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)
这里的错误很微妙。通过使用 `in_axes=1` 进行映射,此代码通过组合批次中每个密钥缓冲区的单个元素来创建新密钥。生成的密钥彼此不同,但实际上是以非标准方式“派生”的。同样,PRNG 的设计和测试并非旨在从这样的密钥批次生成独立的随机流。
使用 `random.key(0)` 创建的新式类型化密钥通过隐藏单个密钥的缓冲区表示来解决此问题,而是将密钥视为密钥数组的不透明元素。密钥数组没有要索引、转置或映射的尾随“缓冲区”维度。
密钥重用#
与 numpy.random 等基于状态的 PRNG API 不同,JAX 的函数式 PRNG 在使用后不会隐式更新密钥。
# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,)) # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(random.key(0))
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))
我们正在积极开发工具来检测和防止意外的密钥重用。这项工作仍在进行中,但它依赖于类型化密钥数组。现在升级到类型化密钥可以为我们准备好在我们构建这些安全功能时引入它们。
类型化 PRNG 密钥的设计#
类型化 PRNG 密钥作为 JAX 中扩展数据类型的一种实例来实现,而新的 PRNG 数据类型是其子数据类型。
扩展数据类型#
从用户角度看,扩展数据类型 `dt` 具有以下用户可见的属性:
jax.dtypes.issubdtype(dt, jax.dtypes.extended)返回True:这是用于检测数据类型是否为扩展数据类型的公共 API。它具有一个类级别属性
dt.type,它返回numpy.generic层次结构中的一个类型类。这类似于np.dtype('int32').type返回numpy.int32,它不是一个数据类型而是一个标量类型,并且是numpy.generic的子类。与 NumPy 标量类型不同,我们不允许实例化
dt.type标量对象:这符合 JAX 将标量值表示为零维数组的决定。
从非公共实现的角度来看,扩展数据类型具有以下属性:
其类型是私有基类
jax._src.dtypes.ExtendedDtype的子类,该基类用于扩展数据类型。`ExtendedDtype` 的实例类似于 `np.dtype` 的实例,例如 `np.dtype('int32')`。它有一个私有的 `_rules` 属性,允许数据类型定义其在特定操作下的行为。例如,当 `dtype` 是扩展数据类型时,`jax.lax.full(shape, fill_value, dtype)` 将委托给 `dtype._rules.full(shape, fill_value, dtype)`。
为什么在 PRNG 之外引入扩展数据类型?我们在内部其他地方也使用了相同的扩展数据类型机制。例如,`jax._src.core.bint` 对象,一个用于动态形状实验的边界整数类型,是另一个扩展数据类型。在最近的 JAX 版本中,它满足上述属性(请参阅 jax/_src/core.py#L1789-L1802)。
PRNG 数据类型#
PRNG 数据类型定义为扩展数据类型的一个特定案例。具体来说,此更改引入了一个新的公共标量类型类 `jax.dtypes.prng_key`,它具有以下属性:
>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True
PRNG 密钥数组的数据类型具有以下属性:
>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True
除了上面为扩展数据类型概述的 `key.dtype._rules` 之外,PRNG 数据类型还定义了 `key.dtype._impl`,其中包含定义 PRNG 实现的元数据。PRNG 实现目前由非公共类 `jax._src.prng.PRNGImpl` 定义。目前,`PRNGImpl` 并不打算成为公共 API,但我们可能会很快重新审视它,以便允许完全自定义的 PRNG 实现。
进展#
以下是实现上述设计的关键 Pull Request 的非详尽列表。主要的跟踪问题是 #9263。
通过 `PRNGImpl` 实现可插拔 PRNG:#6899
实现 `PRNGKeyArray`,无数据类型:#11952
为 `PRNGKeyArray` 添加“自定义元素”数据类型属性,并附带 `_rules` 属性:#12167
将“自定义元素类型”重命名为“不透明数据类型”:#12170
重构 `bint` 以使用不透明数据类型基础设施:#12707
添加 `jax.random.key` 以直接创建类型化密钥:#16086
向 `key` 和 `PRNGKey` 添加 `impl` 参数:#16589
将“不透明数据类型”重命名为“扩展数据类型”并定义 `jax.dtypes.extended`:#16824
引入 `jax.dtypes.prng_key` 并将 PRNG 数据类型与扩展数据类型统一:#16781
添加 `jax_legacy_prng_key` 标志以支持在使用遗留(原始)PRNG 密钥时发出警告或错误:#17225