jax.pmap#
- jax.pmap(fun, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None)[source]#
支持集体运算的并行映射。
pmap()
的目的是表达单程序多数据 (SPMD) 程序。将pmap()
应用于函数将使用 XLA 编译该函数(类似于jit()
),然后在 XLA 设备(如多个 GPU 或多个 TPU 核心)上并行执行它。在语义上,它与vmap()
相当,因为这两种转换都将函数映射到数组轴上,但vmap()
通过将映射轴下推到原语操作中来向量化函数,而pmap()
则复制函数并在其自己的 XLA 设备上并行执行每个副本。映射轴大小必须小于或等于可用本地 XLA 设备的数量,如
jax.local_device_count()
返回的值(除非指定了devices
,见下文)。对于嵌套的pmap()
调用,映射轴大小的乘积必须小于或等于 XLA 设备的数量。pmap()
要求所有参与的设备都相同。例如,不可能使用pmap()
在两种不同型号的 GPU 上并行化计算。目前,同一个设备在同一个 pmap 中参与两次是错误的。多进程平台: 在多进程平台(如 TPU Pod)上,
pmap()
旨在用于 SPMD Python 程序,其中每个进程都运行相同的 Python 代码,以便所有进程都以相同的顺序运行相同的 pmapped 函数。每个进程仍应使用等于本地设备数量的映射轴大小调用 pmapped 函数(除非指定了devices
,见下文),并且通常会返回具有相同前导轴大小的数组。但是,fun
中的任何集体运算都将通过设备到设备通信在所有参与设备(包括其他进程上的设备)上计算。从概念上讲,这可以被认为是运行在跨进程分片的单个数组上的 pmap,其中每个进程“仅看到”其输入和输出的本地分片。SPMD 模型要求相同的多进程 pmap 必须在所有设备上以相同的顺序运行,但它们可以与在单个进程中运行的任意操作穿插。- 参数:
fun (Callable) – 要在参数轴上映射的函数。其参数和返回值应为数组、标量或(嵌套的)标准 Python 容器(元组/列表/字典)或其组合。由
static_broadcasted_argnums
指示的位置参数可以是任何东西,只要它们是可哈希的并且定义了相等性操作。axis_name (AxisName | None | None) – 可选,一个可哈希的 Python 对象,用于标识映射轴,以便可以应用并行集体运算。
in_axes (int | None | Sequence[Any]) – 一个非负整数、None 或其嵌套的 Python 容器,用于指定要映射的位置参数的轴。作为关键字传递的参数始终映射到其前导轴(即轴索引 0)。有关详细信息,请参阅
vmap()
。out_axes (Any) – 一个非负整数、None 或其嵌套的 Python 容器,指示映射轴应出现在输出中的位置。所有具有映射轴的输出都必须具有非 None 的
out_axes
规范(请参阅vmap()
)。static_broadcasted_argnums (int | Iterable[int]) –
一个整数或整数集合,指定要将哪些位置参数视为静态(编译时常量)。仅依赖于静态参数的操作将被常量折叠。使用这些常量的不同值调用 pmapped 函数将触发重新编译。如果使用的位置参数少于
static_broadcasted_argnums
指示的数量调用 pmapped 函数,则会引发错误。每个静态参数都将广播到所有设备。不是数组或其容器的参数必须标记为静态。默认为 ()。静态参数必须是可哈希的,这意味着实现了
__hash__
和__eq__
,并且应该是不可变的。devices (Sequence[xc.Device] | None | None) – 这是一个实验性功能,API 可能会发生变化。可选,要映射的设备序列。(可用设备可以通过 jax.devices() 检索)。在多进程设置中必须为每个进程给出相同的设备序列(因此将包括跨进程的设备)。如果指定,则映射轴的大小必须等于给定进程本地设备序列中的设备数量。尚不支持在内部或外部
pmap()
中指定devices
的嵌套pmap()
。backend (str | None | None) – 这是一个实验性功能,API 可能会发生变化。可选,表示 XLA 后端的字符串。 ‘cpu’、‘gpu’ 或 ‘tpu’。
axis_size (int | None | None) – 可选;映射轴的大小。
donate_argnums (int | Iterable[int]) –
指定哪些位置参数缓冲区“捐赠”给计算。如果您在计算完成后不再需要参数缓冲区,则可以安全地捐赠它们。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应重用捐赠给计算的缓冲区,如果您尝试这样做,JAX 将引发错误。请注意,donate_argnums 仅适用于位置参数,关键字参数将不会被捐赠。
有关缓冲区捐赠的更多详细信息,请参阅 FAQ。
global_arg_shapes (tuple[tuple[int, ...], ...] | None | None)
- 返回:
fun
的并行化版本,其参数与fun
的参数相对应,但在in_axes
指示的位置具有额外的数组轴,并且输出具有额外的前导数组轴(大小相同)。- 返回类型:
Any
例如,假设有 8 个 XLA 设备可用,
pmap()
可以用作沿前导数组轴的映射>>> import jax.numpy as jnp >>> >>> out = pmap(lambda x: x ** 2)(jnp.arange(8)) >>> print(out) [0, 1, 4, 9, 16, 25, 36, 49]
当前导维度小于可用设备数量时,JAX 将仅在设备子集上运行
>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) >>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2 >>> out = pmap(jnp.dot)(x, y) >>> print(out) [[[ 4. 9.] [ 12. 29.]] [[ 244. 345.] [ 348. 493.]] [[ 1412. 1737.] [ 1740. 2141.]]]
如果您的前导维度大于可用设备数量,您将收到错误
>>> pmap(lambda x: x ** 2)(jnp.arange(9)) ValueError: ... requires 9 replicas, but only 8 XLA devices are available
与
vmap()
一样,在in_axes
中使用None
表示参数没有额外的轴,应该跨副本广播而不是映射>>> x, y = jnp.arange(2.), 4. >>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y) >>> print(out) ([4., 5.], [8., 8.])
请注意,
pmap()
始终返回映射到其前导轴的值,等效于在vmap()
中使用out_axes=0
。除了表达纯映射之外,
pmap()
还可以用于表达通过集体运算进行通信的并行单程序多数据 (SPMD) 程序。例如>>> f = lambda x: x / jax.lax.psum(x, axis_name='i') >>> out = pmap(f, axis_name='i')(jnp.arange(4.)) >>> print(out) [ 0. 0.16666667 0.33333334 0.5 ] >>> print(out.sum()) 1.0
在此示例中,
axis_name
是一个字符串,但它可以是任何定义了__hash__
和__eq__
的 Python 对象。pmap()
的参数axis_name
命名了映射轴,以便集体运算(如jax.lax.psum()
)可以引用它。轴名称在嵌套pmap()
函数的情况下尤其重要,在这些情况下,集体运算可以在不同的轴上操作>>> from functools import partial >>> import jax >>> >>> @partial(pmap, axis_name='rows') ... @partial(pmap, axis_name='cols') ... def normalize(x): ... row_normed = x / jax.lax.psum(x, 'rows') ... col_normed = x / jax.lax.psum(x, 'cols') ... doubly_normed = x / jax.lax.psum(x, ('rows', 'cols')) ... return row_normed, col_normed, doubly_normed >>> >>> x = jnp.arange(8.).reshape((4, 2)) >>> row_normed, col_normed, doubly_normed = normalize(x) >>> print(row_normed.sum(0)) [ 1. 1.] >>> print(col_normed.sum(1)) [ 1. 1. 1. 1.] >>> print(doubly_normed.sum((0, 1))) 1.0
在多进程平台上,集体运算在所有设备(包括其他进程上的设备)上运行。例如,假设以下代码在两个进程上运行,每个进程有 4 个 XLA 设备
>>> f = lambda x: x + jax.lax.psum(x, axis_name='i') >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8) >>> out = pmap(f, axis_name='i')(data) >>> print(out) [28 29 30 31] # on process 0 [32 33 34 35] # on process 1
每个进程传入一个不同的长度为 4 的数组,对应于其 4 个本地设备,并且 psum 在所有 8 个值上运算。从概念上讲,这两个长度为 4 的数组可以被认为是分片的长度为 8 的数组(在本例中等效于 jnp.arange(8)),该数组被映射,映射轴长度为 8,名称为 ‘i’。然后,每个进程上的 pmap 调用返回相应的长度为 4 的输出分片。
devices
参数可用于精确指定哪些设备用于运行并行计算。例如,再次假设单个进程有 8 个设备,以下代码定义了两个并行计算,一个在前六个设备上运行,另一个在剩余的两个设备上运行>>> from functools import partial >>> @partial(pmap, axis_name='i', devices=jax.devices()[:6]) ... def f1(x): ... return x / jax.lax.psum(x, axis_name='i') >>> >>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:]) ... def f2(x): ... return jax.lax.psum(x ** 2, axis_name='i') >>> >>> print(f1(jnp.arange(6.))) [0. 0.06666667 0.13333333 0.2 0.26666667 0.33333333] >>> print(f2(jnp.array([2., 3.]))) [ 13. 13.]