协同 Python (Colocated Python)#
注意:协同 Python 目前是一个实验性 API。其功能和接口可能会发生变化,且不遵循标准的 JAX 兼容性政策。
协同 Python 提供了一种统一的方法,用于在与一组 JAX 设备关联的主机上运行 Python 代码。如果 JAX 设备代表本地设备,Python 代码将在本地主机上运行。如果 JAX 设备代表远程设备,Python 代码将被发送到这些远程设备所在的主机上运行。这对于在 JAX 之上构建多主机机器学习系统非常有用,该系统既可在多控制器 JAX 环境(在每个装有加速器的主机上运行 JAX 代码)中移植,也可在单控制器 JAX 环境(在单个协调其他装有加速器主机的主机上运行 JAX 代码)中移植。
协同 CPU 设备#
要使用协同 Python,第一步是获取与目标加速器设备协同的 CPU 设备。jax.experimental.colocated_python.colocated_cpu_devices 提供了一种实现此目的的标准方法。
import jax
import jax.experimental.colocated_python as colocated_python
devices = jax.devices()
cpu_devices = colocated_python.colocated_cpu_devices(devices)
print(cpu_devices)
[CpuDevice(id=0)]
和往常一样,这些 CPU 设备可以与 JAX API 一起使用。
cpu_mesh = jax.sharding.Mesh(cpu_devices, ["x"])
cpu_sharding = jax.sharding.NamedSharding(cpu_mesh, jax.P())
x = jax.device_put(1, cpu_sharding)
y = jax.jit(lambda x: x + 1)(x)
print(y)
2
协同 Python 函数#
CPU 设备也可用于通过协同 Python 运行 Python 代码。
def f(x):
return x + 1
f = colocated_python.colocated_python(f)
y = f(x)
assert y.sharding == x.sharding
print(y)
2
由于协同 Python 运行的是常规 Python 代码,你也可以执行 I/O 操作。
def f(x):
with open('/tmp/foo', 'w') as f:
f.write(str(x))
return x
f = colocated_python.colocated_python(f)
jax.block_until_ready(f(x))
Array(1, dtype=int32, weak_type=True)
请注意使用 jax.block_until_ready 来确保 Python 代码已执行完毕。原则上,协同 Python 调用可能是异步运行的,类似于 jitted 函数调用;调用会返回 JAX 数组,且不会在产出结果前阻塞。因此,如果执行完成很重要,你应该对协同 Python调用的输出进行阻塞。
存在协同 Python 调用同步运行的情况。
如果调用协同 Python 函数时没有使用“特化”(见下文),则第一次调用将同步运行。这是因为异步执行必须预先知道输出的形状(shape)和分片(sharding),而协同 Python 必须先运行一次 Python 代码以获取这些信息。
某些 JAX 后端尚未完全支持异步执行,将会回退到同步执行。
包装后的 Python 代码必须在输入和输出中使用完全相同的一组设备。这是一个类似于代表 SPMD 执行的 jitted 函数的要求。
特化 (Specialization)#
协同 Python 中的特化是一种机制,用于在无法预先推断信息,或者你希望确保协同 Python 执行完全按照指定方式进行时,提供关于协同 Python 函数的输入、输出和执行的额外信息。
首先,包装在协同 Python 中的函数具有一个 specialize 方法。该方法用于创建一个新的、使用所提供信息进行特化的协同 Python 包装函数。
out_specs_fn 是一个函数,它接收一个调用输入的 jax.ShapeDtypeStruct pytree,并返回输出所预期的 jax.ShapeDtypeStruct pytree。调用此函数类似于 jitted 函数的追踪,但该函数与原始 Python 代码是分离的。此函数在调用方一侧运行,不会在设备上执行。
def f(x):
return x + 1
f = colocated_python.colocated_python(f)
f = f.specialize(out_specs_fn=lambda x: x)
y = f(x)
assert y.sharding == x.sharding
in_specs 接收一个具体的 jax.sharding.ShapeDtypeStruct pytree(顶层为元组),作为协同 Python 函数调用的预期输入。如果必须使用特定的输入规范,或者输出规范函数只能针对具体的输入规范进行计算时,会用到它。
import jax.numpy as jnp
def f(x):
return x + 1
f = colocated_python.colocated_python(f)
f = f.specialize(
in_specs=(
# args
(
jax.ShapeDtypeStruct(
shape=(), dtype=jnp.int32, sharding=cpu_sharding
),
),
# kwargs
{},
),
out_specs_fn=lambda x: jax.ShapeDtypeStruct(
shape=(), dtype=jnp.int32, sharding=cpu_sharding
),
)
f(x) # `x` must match the input spec.
Array(2, dtype=int32, weak_type=True)
devices 指定了协同 Python 函数应该运行的设备列表。对 devices 进行特化,可以让不带输入参数的协同 Python 函数运行。
def f():
with open('/tmp/foo', 'w') as f:
f.write('foo')
return
f = colocated_python.colocated_python(f)
f = f.specialize(devices=cpu_devices)
f() # Would be an error if `f` is not specialized with ``devices``.
协同 Python 类#
协同 Python 也支持包装 Python 类。真正的实例会在与设备关联的主机上创建,而调用方将获得一个包装类,它使用协同 Python 将所有方法调用转发给真实的实例。
class Adder:
def __init__(self, increment):
print('Adder created')
self.increment = increment
def __del__(self):
print('Adder destroyed')
def add(self, x):
return x + self.increment
Adder = colocated_python.colocated_python_class(Adder)
adder = Adder(1)
x = jax.device_put(1, cpu_sharding)
y = adder.add(x)
print(y)
Adder created
2
当包装类实例被销毁时,真实实例也会被销毁。请注意,这种销毁是异步的。
del adder
Adder destroyed
协同 Python 与普通 Python 之间存在几个重要的语义差异。
协同 Python 类实例仅在首次调用任何非构造函数方法时,在与设备关联的主机上创建。在上述示例中,
Adder(1)捕获了构造函数参数1,但主机上真正的构造函数调用Adder(1)仅在第一次调用adder.add(x)时发生。这是因为在调用其方法之前,无法确定应该在哪些主机上创建Adder实例。如果同一个包装类的方法被使用不同设备的输入调用,那么真实的实例可能会在不同时间、不同主机上创建。如果第一次方法调用使用了主机 A 上的 CPU 设备,而第二次方法调用使用了主机 B 上的 CPU 设备,则真实实例将在第一次方法调用期间在主机 A 上创建,然后在第二次方法调用期间在主机 B 上创建。
协同 Python 类的方法目前尚不支持特化。该支持将在未来添加。
执行顺序与并发#
协同 Python 提供“程序顺序”执行。即使协同 Python 调用可能是异步的(返回输出 JAX 数组而不阻塞),调用也会按照用户程序中调用的顺序执行。因此,默认情况下,协同 Python 调用是顺序执行的。
协同 Python 的几个用例将受益于并发执行。例如,一个协同 Python 调用可能因为正在执行昂贵的文件读取而需要很长时间才能返回,而另一个协同 Python 调用可能需要执行与前者无关的文件写入。在这种情况下,可以期待两个调用并发运行而互不阻塞。
如果从不同的线程进行协同 Python 调用,协同 Python 会提供并发执行。例如,下面的示例将使两个协同 Python 调用并发运行。
import concurrent.futures
import time
def f(x):
time.sleep(1)
return x + 1
f = colocated_python.colocated_python(f)
f = f.specialize(out_specs_fn=lambda x: x) # Calls will be asynchronous.
with concurrent.futures.ThreadPoolExecutor(2) as executor:
fut1 = executor.submit(f, x)
fut2 = executor.submit(f, x)
# Will finish in approximately 1 second instead of 2 seconds.
jax.block_until_ready([fut1.result(), fut2.result()])
虽然来自不同线程的调用会并发运行,但在每个线程内部,程序顺序仍然适用。