分布式数据加载#
本高级指南演示了如何在 多主机或多进程环境 中运行 JAX 时执行分布式数据加载,其中 JAX 计算所需的数据分布在多个进程中。本文档涵盖了思考分布式数据加载的总体方法,然后介绍了如何将其应用于数据并行(更简单)和模型并行(更复杂)的工作负载。
与替代方法相比,分布式数据加载通常更高效(数据分布在进程之间),但也更复杂。例如:1) 在单个进程中加载完整的全局数据,然后将其拆分并通过 RPC 发送给其他进程;2) 在所有进程中加载完整的全局数据,然后仅在每个进程中使用所需的部分。加载完整的全局数据通常更简单,但成本更高。例如,在机器学习中,训练循环可能会在等待数据时被阻塞,并且每个进程都会使用额外的网络带宽。
注意
在使用分布式数据加载时,重要的是每个设备(例如,每个 GPU 或 TPU)都能访问运行计算所需的数据分片。这通常会使分布式数据加载比上述替代方法更复杂,并且更难正确实现。如果错误的数据分片最终出现在错误的设备上,计算仍然可以运行而不会出错,因为计算无法知道输入数据“应该”是什么。但是,最终结果通常会不正确,因为输入数据与预期的不同。
加载 jax.Array
的通用方法#
考虑一种从 JAX 未产生的原始数据创建单个 jax.Array
的情况。这些概念不仅适用于加载批量数据记录,还适用于任何非 JAX 计算直接产生的多进程 jax.Array
。例如:1) 从检查点加载模型权重;或 2) 加载大型空间分片的图像。
每个 jax.Array
都有一个关联的 Sharding
,它描述了每个全局设备所需全局数据的分片。当您从头开始创建 jax.Array
时,您还需要创建其 Sharding
。这样 JAX 才能理解数据如何在设备之间布局。您可以创建任何您想要的 Sharding
。实际上,您通常会根据您正在实现的并行策略来选择 Sharding
(您将在本指南的后面更详细地了解数据并行和模型并行)。您还可以根据原始数据将在每个进程中如何生成来选择 Sharding
。
一旦定义了 Sharding
,您就可以使用 addressable_devices()
来提供在当前进程中加载数据所需的设备列表。(注意:“可寻址设备”比“本地设备”是一个更通用的术语。目标是确保每个进程的数据加载器都为其本地设备提供正确的数据。)
示例#
例如,考虑一个您需要在 4 个进程(每个进程 2 个设备,总共 8 个设备)上分片的 (64, 128)
jax.Array
。这将导致 8 个唯一的数据分片,每个设备一个。有许多方法可以分片此 jax.Array
。您可以沿 jax.Array
的第二个维度执行一维分片,使每个设备获得一个 (64, 16)
的分片,如下所示:
在上图中,每个数据分片都有自己的颜色,以指示哪个进程需要加载该分片。例如,您假设进程 0
的 2 个设备包含分片 A
和 B
,对应于全局数据的第一个 (64, 32)
部分。
您可以选择不同的分片到设备的分布。例如:
这是另一个示例 — 二维分片:
无论 jax.Array
如何分片,您都必须确保每个进程的数据加载器都能获得/加载全局数据所需的正确分片。有几种高级方法可以实现这一点:1) 在每个进程中加载全局数据;2) 使用每个设备的管道加载数据;3) 使用合并的每个进程的管道加载数据;4) 以方便的方式加载数据,然后在计算内部重分片。
选项 1:在每个进程中加载全局数据#
使用此选项,每个进程
加载所需的完整值;然后
仅将所需的分片传输到该进程的本地设备。
这不是一种高效的分布式数据加载方法,因为每个进程都会丢弃其本地设备不需要的数据,并且总摄入数据量可能高于必要。但此选项有效且相对易于实现,对于某些工作负载(例如,如果全局数据很小)来说,性能开销可能还可以接受。
选项 2:使用每个设备的管道加载数据#
在此选项中,每个进程为其每个本地设备设置一个数据加载器(即,每个设备都有自己的数据加载器,仅用于它所需的数据分片)。
这在加载的数据方面是高效的。有时,将每个设备视为独立于进程的所有本地设备(请参阅下面的选项 3:使用合并的每个进程的管道加载数据)会更简单。但是,多个并发数据加载器有时会导致性能问题。
选项 3:使用合并的每个进程的管道加载数据#
如果您选择此选项,每个进程
设置一个数据加载器,用于加载其所有本地设备所需的数据;然后
在传输到每个本地设备之前对本地数据进行分片。
这是最高效的分布式加载方式。但同时它也是最复杂的,因为需要逻辑来确定每个设备需要哪些数据,并创建一个仅加载所有这些数据(理想情况下,不加载任何额外的多余数据)的单一数据加载器。
选项 4:以方便的方式加载数据,在计算内部重分片#
这个选项更难解释,但通常比上述选项(1 到 3)更容易实现。
想象一个场景,在这种场景下,很难甚至不可能设置加载您需要数据的精确数据加载器(无论是每个设备还是每个进程的数据加载器)。但是,仍然可以为每个进程设置一个数据加载器,该加载器加载数据量的 1 / num_processes
,只是分片方式不正确。
然后,继续前面的二维分片示例,假设每个进程更容易加载数据的单个列:
然后,您可以创建一个具有表示按列数据加载的 Sharding
的 jax.Array
,将其直接传递给计算,并使用 jax.lax.with_sharding_constraint()
立即将按列分片的数据重分片到所需的 Sharding
。而且,由于数据是在计算内部重分片的,所以它将在加速器通信链路(例如,TPU ICI 或 NVLink)上重分片。
此选项 4 具有与选项 3(使用合并的每个进程的管道加载数据)类似的优点:
每个进程仍然只有一个数据加载器;并且
全局数据仅在所有进程中加载一次;并且
全局数据具有在数据加载方式上提供更多灵活性的附加好处。
但是,此方法使用加速器互连带宽来执行重分片,这可能会减慢某些工作负载的速度。选项 4 还要求输入数据表示为独立的 Sharding
,除了目标 Sharding
之外。
复制#
复制描述了一个过程,其中多个设备具有相同的数据分片。上面提到的通用选项(选项 1 到 4)仍然适用于复制。唯一的区别是某些进程最终可能会加载相同的数据分片。本节介绍完全复制和部分复制。
完全复制#
完全复制是一个过程,其中所有设备都拥有数据的完整副本(即,数据“分片”是整个数组值)。
在下面的示例中,由于总共有 8 个设备(每个进程 2 个),因此您将获得数据的 8 个副本。每个数据副本都是未分片的,即副本存在于单个设备上:
部分复制#
部分复制描述了一个过程,其中有多个数据副本,每个副本跨多个设备分片。对于给定的数组值,通常有许多可能的实现部分复制的方法(注意:对于给定的数组形状,总有一个完全复制的 Sharding
)。
下面是两个可能的示例。
在下面的第一个示例中,每个副本跨进程的两个本地设备分片,总共有 4 个副本。这意味着每个进程都需要加载全局数据的完整副本,因为其本地设备将拥有数据的完整副本。
在下面的第二个示例中,每个副本仍然跨两个设备分片,但每个设备对分布在两个不同的进程中。进程 0
(粉色)和进程 1
(黄色)都需要加载数据的第一行,而进程 2
(绿色)和进程 3
(蓝色)都需要加载数据的第二行。
现在您已经了解了创建 jax.Array
的高级选项,让我们将它们应用于机器学习应用程序的数据加载。
数据并行#
在纯数据并行(无模型并行)中:
您在每个设备上复制模型;然后
每个模型副本(即每个设备)接收不同的每个副本数据批次。
当将输入数据表示为单个 jax.Array
时,该数组包含此步骤所有副本的数据(这称为全局批次),其中 jax.Array
的每个分片包含单个每个副本的批次。您可以将其表示为跨所有设备的 1D 分片(请查看下面的示例) — 换句话说,全局批次由在批次轴上连接起来的所有每个副本的批次组成。
应用此框架,您可能会得出结论,进程 0
应该获得全局批次的第一个四分之一(8 个中的 2 个),而进程 1
应该获得第二个,以此类推。
但是,您如何知道第一个四分之一是多少?您如何确保进程 0
获得第一个四分之一?幸运的是,关于数据并行有一个非常重要的技巧,这意味着您不必回答这些问题,并且可以使整个设置更简单。
关于数据并行的重要技巧#
技巧在于,您不需要关心哪个每个副本的批次落在哪个副本上。因此,哪个进程加载批次并不重要。原因是,由于每个设备对应一个执行相同操作的模型副本,所以哪个设备获得全局批次中的哪个每个副本的批次并不重要。
这意味着您可以自由地重新排列全局批次中的每个副本批次。换句话说,您可以自由地随机化每个设备获得的数据分片。
例如
通常,重新排列 jax.Array
的数据分片(如上所示)不是一个好主意 — 您实际上是在排列 jax.Array
的值!然而,对于数据并行,全局批次的顺序没有意义,您可以自由地重新排列全局批次中的每个副本批次,如前所述。
这简化了数据加载,因为它意味着每个设备只需要一个独立的每个副本批次流,这在大多数数据加载器中可以通过为每个进程创建独立的管道并将生成的每个进程批次分块成每个副本批次来轻松实现。
这是选项 2:合并的每个进程的管道加载数据的一个实例。您也可以使用其他选项(例如本文档前面介绍的 0、1 和 3),但这个选项相对简单且高效。
以下是使用 tf.data 实现此设置的示例:
import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: setup the Dataset for pure data parallelism (do once)
################################################################################
# Fake example data (replace with your Dataset)
ds = tf.data.Dataset.from_tensor_slices(
[np.ones((16, 3)) * i for i in range(100)])
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step). This can be used with batches
# produced by different data loaders as well!
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
mesh = jax.make_mesh((jax.device_count(),), ('batch',))
sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))
global_batch_array = jax.make_array_from_process_local_data(
sharding, per_process_batch)
数据 + 模型并行#
在模型并行中,您在多个设备上分片每个模型副本。如果您使用纯模型并行(无数据并行):
只有一个模型副本分片在所有设备上;然后
数据(通常)完全复制在所有设备上。
本指南考虑了您同时使用数据和模型并行的情况:
您将多个模型副本中的每个副本分片到多个设备上;然后
您将数据部分复制到每个模型副本上 — 同一模型副本中的每个设备获得相同的每个副本批次,而跨模型副本的设备获得不同的每个副本批次。
进程内的模型并行#
为了数据加载的目的,最简单的方法可能是将每个模型副本分片到单个进程的本地设备上。
在此示例中,我们改为使用 2 个进程(每个进程 4 个设备),而不是 4 个进程(每个进程 2 个设备)。考虑一个场景,其中每个模型副本分片到单个进程的 2 个本地设备上。这会导致每个进程中有 2 个模型副本,总共 4 个模型副本,如下所示:
这里,再次,输入数据表示为具有 1D 分片的单个 jax.Array
,其中每个分片是每个副本的批次,但有一个例外:
与纯数据并行情况不同,您引入了部分复制,并创建了 1D 分片全局批次的 2 个副本。
这是因为每个模型副本由 2 个设备组成,每个设备都需要一个每个副本批次的副本。
将每个模型副本保留在单个进程内可以使事情更简单,因为您可以重用上面描述的纯数据并行设置,只是您还需要复制每个副本的批次。
注意
将每个副本的批次复制到正确的设备上也非常重要!虽然关于数据并行的那个非常重要的技巧意味着您不关心哪个批次最终落在哪个副本上,但您确实关心单个副本只获得一个批次。
例如,这是可以的:
但是,如果您不注意将每个批次加载到哪个本地设备上,您可能会意外地创建未复制的数据,即使 Sharding
(和并行策略)表明数据是复制的。
如果您意外地创建了一个应在单个进程内复制但数据未复制的 jax.Array
,JAX 将引发错误(这并不总是适用于跨进程的模型并行;请参阅下一节)。
以下是如何使用 tf.data
实现每个进程的模型并行和数据并行的方法:
import jax
import tensorflow as tf
import numpy as np
################################################################################
# Step 1: Set up the Dataset with a different data shard per-process (do once)
# (same as for pure data parallelism)
################################################################################
# Fake example data (replace with your Dataset)
per_process_batches = [np.ones((16, 3)) * i for i in range(100)]
ds = tf.data.Dataset.from_tensor_slices(per_process_batches)
ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())
################################################################################
# Step 2: Create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step)
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()
num_model_replicas_per_process = 2 # set according to your parallelism strategy
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()
# Create an example `Mesh` for per-process data parallelism. Make sure all devices
# are grouped by process, and then resize so each row is a model replica.
mesh_devices = np.array([jax.local_devices(process_idx)
for process_idx in range(jax.process_count())])
mesh_devices = mesh_devices.reshape(num_model_replicas_total, -1)
# Double check that each replica's devices are on a single process.
for replica_devices in mesh_devices:
num_processes = len(set(d.process_index for d in replica_devices))
assert num_processes == 1
mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])
# Shard the data across model replicas. You don't shard across the
# data_parallelism mesh axis, meaning each per-replica shard will be replicated
# across that axis.
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("model_replicas"))
global_batch_array = jax.make_array_from_process_local_data(
sharding, per_process_batch)
跨进程的模型并行#
当模型副本分布在进程之间时,情况会变得更有趣,无论是因为:
单个副本无法在单个进程内容纳;或者
设备分配不是这样设置的。
例如,回到前面 4 个进程(每个进程 2 个设备)的设置,如果您这样分配设备到副本:
这与前面每个进程的模型并行示例是相同的并行策略 — 4 个模型副本,每个副本跨 2 个设备分片。唯一的区别是设备分配 — 每个副本的两个设备分布在不同的进程中,并且每个进程只负责每个副本批次的单个副本(但对于两个副本)。
将模型副本分布在进程之间可能看起来是任意且不必要的(在这个示例中,可以说是的),但实际部署可能会采用这种设备分配方式,以最佳地利用设备之间的通信链路。
数据加载现在变得更加复杂,因为需要在进程之间进行一些额外的协调。在纯数据并行和每个进程的模型并行情况下,只关注每个进程加载唯一的数据流很重要。现在,某些进程必须加载相同的数据,而某些进程必须加载不同的数据。在上例中,进程 0
和 2
(分别为粉色和绿色)必须加载相同的 2 个每个副本批次,而进程 1
和 3
(分别为黄色和蓝色)也必须加载相同的 2 个每个副本批次(但与进程 0
和 2
的批次不同)。
此外,重要的是每个进程不要混淆其 2 个每个副本的批次。虽然您不关心哪个批次落在哪个副本上(关于数据并行的那个非常重要的技巧),但您需要关心一个副本只获得一个批次。例如,这会很糟糕:
注意
截至 2023 年 8 月,JAX 无法检测跨进程的 jax.Array
分片是否应该复制但实际上没有复制,并且在计算运行时会产生错误的结果。所以要小心不要这样做!
为了在每个设备上获得正确的每个副本批次,您需要将全局输入数据表示为以下 jax.Array
: