传输守卫#

JAX 可能会在类型转换和输入分片期间在主机和设备之间以及设备之间传输数据。为了记录或禁止任何意外的传输,用户可以配置 JAX 传输守卫。

JAX 传输守卫区分两种类型的传输:

  • 显式传输:`jax.device_put*()` 和 `jax.device_get()` 调用。

  • 隐式传输:其他传输(例如,打印 `DeviceArray`)。

传输守卫可以根据其守卫级别采取操作:

  • `"allow"`:静默允许所有传输(默认)。

  • `"log"`:记录并允许隐式传输。静默允许显式传输。

  • `"disallow"`:禁止隐式传输。静默允许显式传输。

  • `"log_explicit"`:记录并允许所有传输。

  • `"disallow_explicit"`:禁止所有传输。

JAX 在禁止传输时将引发 `RuntimeError`。

传输守卫使用标准的 JAX 配置系统:

  • `--jax_transfer_guard=GUARD_LEVEL` 命令行标志和 `jax.config.update("jax_transfer_guard", GUARD_LEVEL)` 将设置全局选项。

  • `with jax.transfer_guard(GUARD_LEVEL): ...` 上下文管理器将在其作用域内设置线程局部选项。

请注意,与其他 JAX 配置选项类似,新创建的线程将使用全局选项,而不是创建该线程的作用域中的任何活动线程局部选项。

传输守卫也可以更具选择性地应用,根据传输方向。标志和上下文管理器名称将以相应的传输方向作为后缀(例如,`--jax_transfer_guard_host_to_device` 和 `jax.config.transfer_guard_host_to_device`)

  • `"host_to_device"`:将 Python 值或 NumPy 数组转换为 JAX 设备上的缓冲区。

  • `"device_to_device"`:将 JAX 设备上的缓冲区复制到不同的设备。

  • `"device_to_host"`:获取 JAX 设备上的缓冲区。

无论传输守卫级别如何,从 CPU 设备获取缓冲区始终是被允许的。

以下显示了使用传输守卫的示例。

>>> jax.config.update("jax_transfer_guard", "allow")  # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x)  # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
...   print("x", x)  # x has already been fetched into the host.
...   print("y", jax.device_get(y))  # Explicit transfers are allowed.
...   try:
...     print("z", z)  # Implicit transfers are disallowed.
...     assert False, "This line is expected to be unreachable."
...   except:
...     print("z could not be fetched")  
x 1
y 2
z could not be fetched