jax.experimental.key_reuse
模块#
实验性密钥复用检查#
该模块包含实验性功能,用于检测 JAX 程序中随机密钥的复用。它正在积极开发中,并且这里的 API 可能会发生变化。以下用法需要 JAX 版本 0.4.26 或更高。
密钥复用检查可以通过 jax_debug_key_reuse
配置启用。可以通过以下方式全局设置
>>> jax.config.update('jax_debug_key_reuse', True)
或者可以通过 jax.debug_key_reuse()
上下文管理器局部启用。启用后,两次使用相同的密钥将导致 KeyReuseError
>>> import jax
>>> with jax.debug_key_reuse(True):
... key = jax.random.key(0)
... val1 = jax.random.normal(key)
... val2 = jax.random.normal(key)
Traceback (most recent call last):
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
密钥复用检查器目前处于实验阶段,但将来我们可能会默认启用它。