并行 Python#
注意:并行 Python 目前是一个实验性 API。其功能和接口可能会发生更改,而不会遵循标准的 JAX 兼容性策略。
并行 Python 提供了一种统一的方式来运行与一组 JAX 设备关联的主机上的 Python 代码。如果 JAX 设备代表本地设备,则 Python 代码将在本地主机上运行。如果 JAX 设备代表远程设备,则 Python 代码将被发送到这些远程设备的主机上运行。当在 JAX 上构建一个多主机 ML 系统时,该系统可以在跨多控制器 JAX 环境(在每个具有加速器的宿主机上运行 JAX 代码)以及单控制器 JAX 环境(在单个宿主机上协调具有加速器的其他宿主机)之间移植时,这很有用。
并行 CPU 设备#
要使用并行 Python,第一步是获取与目标加速器设备并行的 CPU 设备。jax.experimental.colocated_python.colocated_cpu_devices
提供了一种标准的方法来实现这一点。
import jax
import jax.experimental.colocated_python as colocated_python
devices = jax.devices()
cpu_devices = colocated_python.colocated_cpu_devices(devices)
print(cpu_devices)
[CpuDevice(id=0)]
像往常一样,CPU 设备可以与 JAX API 一起使用。
cpu_mesh = jax.sharding.Mesh(cpu_devices, ["x"])
cpu_sharding = jax.sharding.NamedSharding(cpu_mesh, jax.P())
x = jax.device_put(1, cpu_sharding)
y = jax.jit(lambda x: x + 1)(x)
print(y)
2
并行 Python 函数#
CPU 设备也可以用于使用并行 Python 运行 Python 代码。
def f(x):
return x + 1
f = colocated_python.colocated_python(f)
y = f(x)
assert y.sharding == x.sharding
print(y)
2
由于并行 Python 运行正常的 Python 代码,您也可以执行 I/O 操作
def f(x):
with open('/tmp/foo', 'w') as f:
f.write(str(x))
return x
f = colocated_python.colocated_python(f)
jax.block_until_ready(f(x))
Array(1, dtype=int32, weak_type=True)
请注意使用 jax.block_until_ready
来确保 Python 代码已完成。原则上,并行 Python 调用可能像 jitted 函数调用一样异步运行;调用将返回 JAX 数组,并且不会阻塞直到生成其输出。因此,如果执行的完成很重要,您应该阻塞来自并行 Python 调用的输出。
存在并行 Python 调用同步运行的情况。
如果并行 Python 函数在没有“特化”(参见下文)的情况下被调用,第一次调用将同步运行。这是因为必须知道输出的形状和分片才能进行异步执行,并且并行 Python 必须运行一次 Python 代码来发现此信息。
某些 JAX 后端尚未完全支持异步执行,并将回退到同步执行。
包装后的 Python 代码必须在输入和输出中使用完全相同的设备集。这与表示 SPMD 执行的 jitted 函数的要求类似。
特化#
并行 Python 中的特化是一种机制,用于在无法提前推断信息,或者您希望确保并行 Python 执行完全按照指定的方式发生时,提供有关并行 Python 函数的输入、输出和执行的额外信息。
首先,包装在并行 Python 中的函数具有 specialize
方法。此方法用于创建另一个使用所提供信息特化过的并行 Python 包装函数。
out_specs_fn
是一个函数,它接受调用输入的 jax.ShapeDtypeStruct
的 pytree,并返回输出预期的 jax.ShapeDtypeStruct
的 pytree。调用此函数类似于 jitted 函数跟踪,但此函数与原始 Python 代码是分开的。此函数在调用方端运行,而不是在设备上执行。
def f(x):
return x + 1
f = colocated_python.colocated_python(f)
f = f.specialize(out_specs_fn=lambda x: x)
y = f(x)
assert y.sharding == x.sharding
in_specs
接受并行 Python 函数调用的输入的 jax.sharding.ShapeDtypeStruct
的具体 pytree(顶层是元组)。如果必须使用某个输入规格,或者仅能针对具体输入规格计算输出规格函数,则使用此参数。
import jax.numpy as jnp
def f(x):
return x + 1
f = colocated_python.colocated_python(f)
f = f.specialize(
in_specs=(
# args
(
jax.ShapeDtypeStruct(
shape=(), dtype=jnp.int32, sharding=cpu_sharding
),
),
# kwargs
{},
),
out_specs_fn=lambda x: jax.ShapeDtypeStruct(
shape=(), dtype=jnp.int32, sharding=cpu_sharding
),
)
f(x) # `x` must match the input spec.
Array(2, dtype=int32, weak_type=True)
devices
指定并行 Python 函数应运行的设备列表。特化 devices
可以让没有输入参数的并行 Python 函数运行。
def f():
with open('/tmp/foo', 'w') as f:
f.write('foo')
return
f = colocated_python.colocated_python(f)
f = f.specialize(devices=cpu_devices)
f() # Would be an error if `f` is not specialized with ``devices``.
并行 Python 类#
并行 Python 还支持包装 Python 类。一个真实实例将在与设备关联的主机上创建,调用方将获得一个包装类,该类使用并行 Python 将所有方法调用转发给真实实例。
class Adder:
def __init__(self, increment):
print('Adder created')
self.increment = increment
def __del__(self):
print('Adder destroyed')
def add(self, x):
return x + self.increment
Adder = colocated_python.colocated_python_class(Adder)
adder = Adder(1)
x = jax.device_put(1, cpu_sharding)
y = adder.add(x)
print(y)
Adder created
2
当包装类实例被销毁时,真实实例也会被销毁。请注意,此销毁将是异步的。
del adder
Adder destroyed
并行 Python 和普通 Python 之间存在一些重要的语义差异。
并行 Python 类实例仅在与设备关联的主机上创建,当任何非构造函数方法首次被调用时。在上面的示例中,
Adder(1)
捕获了构造函数参数1
,但主机上的实际构造函数调用Adder(1)
仅在第一次调用adder.add(x)
时发生。这是因为在调用其方法之前,无法确定Adder
实例应该在哪些主机上创建。如果同一包装类的同一方法被具有不同设备输入的调用,则真实实例可能在不同时间在不同主机上创建。如果第一次方法调用使用了主机 A 上的 CPU 设备,第二次方法调用使用了主机 B 上的 CPU 设备,则真实实例将在第一次方法调用期间在主机 A 上创建,然后在第二次方法调用期间在主机 B 上创建。
并行 Python 类的方法尚不支持特化。支持将在未来添加。
执行顺序和并发#
并行 Python 提供“程序顺序”执行。即使并行 Python 调用可能是异步的(返回输出 JAX 数组而不阻塞),这些调用也将按照用户程序中调用顺序执行。因此,默认情况下,并行 Python 调用是顺序执行的。
并行 Python 的几种用例将受益于并发执行。例如,一个并行 Python 调用可能需要很长时间才能返回,因为它可能正在进行昂贵的文件读取,而另一个并行 Python 调用可能需要进行与第一个调用无关的文件写入。这种情况可以预期两个调用可以并发运行而不会相互阻塞。
如果并行 Python 调用是从不同线程发出的,并行 Python 将提供并发执行。例如,下面的示例将使两个并行 Python 调用并发运行。
import concurrent.futures
import time
def f(x):
time.sleep(1)
return x + 1
f = colocated_python.colocated_python(f)
f = f.specialize(out_specs_fn=lambda x: x) # Calls will be asynchronous.
with concurrent.futures.ThreadPoolExecutor(2) as executor:
fut1 = executor.submit(f, x)
fut2 = executor.submit(f, x)
# Will finish in approximately 1 second instead of 2 seconds.
jax.block_until_ready([fut1.result(), fut2.result()])
虽然来自不同线程的调用会并发运行,但在每个线程上,程序顺序将继续适用。