传输保护#
JAX 可能会在类型转换和输入分片 (sharding) 期间在主机与设备之间以及不同设备之间传输数据。为了记录或禁止任何意外的传输,用户可以配置 JAX 传输保护。
JAX 传输保护区分两种类型的传输:
显式传输:
jax.device_put*()和jax.device_get()调用。隐式传输:其他传输(例如,打印
DeviceArray)。
传输保护可以根据其保护级别采取相应操作:
"allow":静默允许所有传输(默认)。"log":记录并允许隐式传输。静默允许显式传输。"disallow":禁止隐式传输。静默允许显式传输。"log_explicit":记录并允许所有传输。"disallow_explicit":禁止所有传输。
当禁止某项传输时,JAX 会引发 RuntimeError。
传输保护使用标准的 JAX 配置系统:
--jax_transfer_guard=GUARD_LEVEL命令行标志和jax.config.update("jax_transfer_guard", GUARD_LEVEL)将设置全局选项。with jax.transfer_guard(GUARD_LEVEL): ...上下文管理器将在该上下文管理器的作用域内设置线程局部选项。
请注意,与其他 JAX 配置选项类似,新创建的线程将使用全局选项,而不是使用创建该线程所在作用域的任何活动线程局部选项。
传输保护还可以根据传输方向进行更具选择性的应用。标志和上下文管理器的名称后缀为相应的传输方向(例如,--jax_transfer_guard_host_to_device 和 jax.config.transfer_guard_host_to_device)。
"host_to_device":将 Python 值或 NumPy 数组转换为 JAX 设备上缓冲区。"device_to_device":将 JAX 设备上缓冲区复制到另一台设备。"device_to_host":获取 JAX 设备上缓冲区。
无论传输保护级别如何,在 CPU 设备上获取缓冲区始终是被允许的。
以下是使用传输保护的示例。
>>> jax.config.update("jax_transfer_guard", "allow") # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x) # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
... print("x", x) # x has already been fetched into the host.
... print("y", jax.device_get(y)) # Explicit transfers are allowed.
... try:
... print("z", z) # Implicit transfers are disallowed.
... assert False, "This line is expected to be unreachable."
... except:
... print("z could not be fetched")
x 1
y 2
z could not be fetched