JEP 9263: 类型化密钥与可插拔RNGs#

Jake VanderPlas, Roy Frostig

2023 年 8 月

概述#

今后,JAX 中的 RNG 密钥将更加类型安全且可自定义。单个 PRNG 密钥不再由长度为 2 的 uint32 数组表示,而是由具有特殊 RNG 数据类型(满足 jnp.issubdtype(key.dtype, jax.dtypes.prng_key))的标量数组表示。

目前,仍然可以使用 jax.random.PRNGKey() 创建旧式 RNG 密钥。

>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')

从现在开始,可以使用 jax.random.key() 创建新式 RNG 密钥。

>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>

这种(标量形状的)数组行为与任何其他 JAX 数组相同,但其元素类型是密钥(及相关元数据)。我们也可以创建非标量密钥数组,例如通过将 jax.vmap() 应用于 jax.random.key()

>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
 [0 1]
 [0 2]
 [0 3]]
>>> key_arr.shape
(4,)

除了切换到新的构造函数外,大多数 PRNG 相关代码应继续按预期工作。您可以像以前一样在 jax.random API 中继续使用密钥;例如

# split
new_key, subkey = jax.random.split(key)

# random number generation
data = jax.random.uniform(key, shape=(5,))

然而,并非所有数值运算都适用于密钥数组。它们现在会故意引发错误。

>>> key = key + 1  
Traceback (most recent call last):
TypeError: add does not accept dtypes key<fry>, int32.

如果由于某种原因需要恢复底层缓冲区(旧式密钥),可以使用 jax.random.key_data()

>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)

对于旧式密钥,key_data() 是一个恒等操作。

这对用户意味着什么?#

对于 JAX 用户,此更改目前不需要任何代码更改,但我们希望您会发现此升级是值得的并切换到使用类型化密钥。要尝试此功能,请将 `jax.random.PRNGKey()` 的使用替换为 `jax.random.key()`。这可能会在您的代码中引入属于以下几类的问题:

  • 如果您的代码对密钥执行不安全/不受支持的操作(例如索引、算术运算、转置等;请参见下面的类型安全部分),此更改将检测到它。您可以更新代码以避免此类不受支持的操作,或者使用 jax.random.key_data()jax.random.wrap_key_data() 以不安全的方式操作原始密钥缓冲区。

  • 如果您的代码包含关于 key.shape 的显式逻辑,您可能需要更新此逻辑,以考虑尾随密钥缓冲区维度不再是形状的显式部分这一事实。

  • 如果您的代码包含关于 key.dtype 的显式逻辑,您将需要升级它以使用新的公共 API 来推断 RNG 数据类型,例如 dtypes.issubdtype(dtype, dtypes.prng_key)

  • 如果您调用的基于 JAX 的库尚未处理类型化 PRNG 密钥,您目前可以使用 raw_key = jax.random.key_data(key) 来恢复原始缓冲区,但请保留一个 TODO,以便在下游库支持类型化 RNG 密钥后将其删除。

未来某个时候,我们计划弃用 jax.random.PRNGKey() 并要求使用 jax.random.key()

检测新式类型化密钥#

要检查对象是否为新式类型化 PRNG 密钥,可以使用 jax.dtypes.issubdtypejax.numpy.issubdtype

>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False

PRNG 密钥的类型注解#

旧式和新式 PRNG 密钥的推荐类型注解是 jax.Array。PRNG 密钥与其他数组的区别在于其 dtype,目前无法在类型注解中指定 JAX 数组的 `dtype`。以前可以使用 jax.random.KeyArrayjax.random.PRNGKeyArray 作为类型注解,但这些在类型检查下始终被别名为 Any,因此 jax.Array 具有更高的特异性。

注意:jax.random.KeyArrayjax.random.PRNGKeyArray 在 JAX 0.4.16 版本中已弃用,并在 JAX 0.4.24 版本中移除。.

JAX 库作者须知#

如果您维护一个基于 JAX 的库,您的用户也是 JAX 用户。请注意,JAX 目前将继续在 jax.random 中支持“原始”旧式密钥,因此调用者可能会期望它们在所有地方都仍然被接受。如果您更喜欢在您的库中要求使用新式类型化密钥,那么您可能希望通过以下方式进行检查:

from jax import dtypes

def ensure_typed_key_array(key: Array) -> Array:
  if dtypes.issubdtype(key.dtype, dtypes.prng_key):
    return key
  else:
    raise TypeError("New-style typed JAX PRNG keys required")

动机#

此次更改的两个主要驱动因素是可定制性和安全性。

自定义 PRNG 实现#

JAX 目前使用单一的、全局配置的 PRNG 算法。PRNG 密钥是无符号 32 位整数的向量,JAX.random API 会使用它来生成伪随机流。任何更高维的 `uint32` 数组都被解释为此类密钥缓冲区的数组,其中尾随维度表示密钥。

随着我们引入替代 PRNG 实现(必须通过设置全局或局部配置标志来选择),这种设计的缺点变得更加明显。不同的 PRNG 实现具有不同大小的密钥缓冲区和不同的随机位生成算法。使用全局标志确定此行为容易出错,特别是在进程中使用了多个密钥实现时。

我们的新方法是将实现作为 PRNG 密钥类型的一部分,即作为密钥数组的元素类型。使用新的密钥 API,这里有一个在默认 `threefry2x32` 实现(用纯 Python 实现并用 JAX 编译)和非默认 `rbg` 实现(对应于单个 XLA 随机位生成操作)下生成伪随机值的示例:

>>> key = jax.random.key(0, impl='threefry2x32')  # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.947667  , 0.9785799 , 0.33229148], dtype=float32)

>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)

PRNG 密钥的安全使用#

原则上,PRNG 密钥只旨在支持少数操作,即密钥派生(例如拆分)和随机数生成。PRNG 旨在生成独立的伪随机数,前提是密钥已正确拆分且每个密钥只使用一次。

以其他方式操作或使用密钥数据的代码通常表明存在意外错误,并且将密钥数组表示为原始 `uint32` 缓冲区使得沿着这些方向的误用变得容易。以下是我们实际遇到的一些误用示例:

密钥缓冲区索引#

访问底层整数缓冲区使得尝试以非标准方式派生密钥变得容易,有时会带来意想不到的严重后果。

# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1])  # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)

如果此密钥是使用 random.key(999) 创建的新式类型化密钥,则对密钥缓冲区进行索引将改为报错。

密钥算术运算#

密钥算术运算是另一种从其他密钥派生密钥的危险方式。通过直接操作密钥数据而非使用 jax.random.split()jax.random.fold_in() 来派生密钥,会生成一批密钥,这些密钥(取决于 PRNG 实现)可能会在批次内生成相关的随机数。

# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)

使用 random.key(0) 创建的新式类型化密钥通过禁止对密钥进行算术运算来解决此问题。

密钥缓冲区的意外转置#

使用“原始”旧式密钥数组,很容易意外地交换批处理(前导)维度和密钥缓冲区(尾随)维度。这再次可能导致密钥产生相关的伪随机性。我们长期以来看到的一种模式归结为:

# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)

这里的错误很微妙。通过对 in_axes=1 进行映射,此代码通过组合批处理中每个密钥缓冲区的一个元素来创建新密钥。生成的密钥彼此不同,但实际上是以非标准方式“派生”的。同样,PRNG 并非设计或测试用于从此类密钥批次生成独立的随机流。

使用 random.key(0) 创建的新式类型化密钥通过隐藏单个密钥的缓冲区表示来解决此问题,转而将密钥视为密钥数组的不透明元素。密钥数组没有可供索引、转置或映射的尾随“缓冲区”维度。

密钥重用#

numpy.random 等基于状态的 PRNG API 不同,JAX 的函数式 PRNG 在密钥使用后不会隐式更新。

# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,))  # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(random.key(0))
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))

我们正在积极开发工具来检测和防止意外的密钥重用。这仍在进行中,但它依赖于类型化密钥数组。现在升级到类型化密钥为我们将来引入这些安全功能奠定了基础。

类型化 PRNG 密钥的设计#

类型化 PRNG 密钥在 JAX 中实现为扩展数据类型(extended dtypes)的一个实例,其中新的 PRNG 数据类型是其子数据类型。

扩展数据类型#

从用户的角度来看,扩展数据类型 `dt` 具有以下用户可见的属性:

  • jax.dtypes.issubdtype(dt, jax.dtypes.extended) 返回 True:这是应该用于检测数据类型是否为扩展数据类型的公共 API。

  • 它有一个类级别属性 dt.type,它返回 numpy.generic 层次结构中的一个类型类。这类似于 np.dtype('int32').type 返回 numpy.int32 的方式,numpy.int32 不是数据类型,而是一个标量类型,并且是 numpy.generic 的子类。

  • 与 NumPy 标量类型不同,我们不允许实例化 dt.type 标量对象:这与 JAX 将标量值表示为零维数组的决定一致。

从非公共实现的角度来看,扩展数据类型具有以下属性:

  • 它的类型是私有基类 jax._src.dtypes.ExtendedDtype 的子类,它是用于扩展数据类型的非公共基类。ExtendedDtype 的实例类似于 np.dtype 的实例,例如 np.dtype('int32')

  • 它有一个私有属性 _rules,允许数据类型定义其在特定操作下的行为。例如,当 dtype 是扩展数据类型时,jax.lax.full(shape, fill_value, dtype) 将委托给 dtype._rules.full(shape, fill_value, dtype)

为什么在 PRNG 之外普遍引入扩展数据类型?我们在内部其他地方重用了相同的扩展数据类型机制。例如,用于动态形状实验的 jax._src.core.bint 对象(一种有界整数类型)是另一种扩展数据类型。在最近的 JAX 版本中,它满足上述属性(请参阅 jax/_src/core.py#L1789-L1802)。

PRNG 数据类型#

PRNG 数据类型被定义为扩展数据类型的一个特殊情况。具体来说,此更改引入了一个新的公共标量类型类 `jax.dtypes.prng_key`,它具有以下属性:

>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True

PRNG 密钥数组因此具有以下属性的数据类型:

>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True

除了通常为扩展数据类型概述的 key.dtype._rules 之外,PRNG 数据类型还定义了 key.dtype._impl,其中包含定义 PRNG 实现的元数据。PRNG 实现目前由非公共类 jax._src.prng.PRNGImpl 定义。目前,PRNGImpl 不打算作为公共 API,但我们可能会很快重新审视这一点,以允许完全自定义的 PRNG 实现。

进展#

以下是实现上述设计的一些主要拉取请求(Pull Request)的非详尽列表。主要跟踪问题是 #9263

  • 通过 PRNGImpl 实现可插拔 PRNG:#6899

  • 实现 PRNGKeyArray,不带 `dtype`:#11952

  • PRNGKeyArray 添加带 `_rules` 属性的“自定义元素”数据类型属性:#12167

  • 将“自定义元素类型”重命名为“不透明数据类型”:#12170

  • 重构 bint 以使用不透明数据类型基础设施:#12707

  • 添加 jax.random.key 以直接创建类型化密钥:#16086

  • keyPRNGKey 添加 impl 参数:#16589

  • 将“不透明数据类型”重命名为“扩展数据类型”并定义 jax.dtypes.extended#16824

  • 引入 jax.dtypes.prng_key 并统一 PRNG 数据类型与扩展数据类型:#16781

  • 添加 jax_legacy_prng_key 标志,以支持在使用旧版(原始)PRNG 密钥时发出警告或报错:#17225