分布式数据加载#

本高级指南演示了如何执行分布式数据加载 — 当你在多主机或多进程环境中运行 JAX,并且 JAX 计算所需的数据分布在多个进程中时。本文档涵盖了分布式数据加载的整体思路,以及如何将其应用于数据并行(更简单)和模型并行(更复杂)的工作负载。

分布式数据加载通常比其他替代方案更高效(数据在进程间分割),但也更复杂,这些替代方案包括:1) 在单个进程中加载完整的全局数据,然后将其分割并通过 RPC 发送所需部分到其他进程;2) 在所有进程中加载完整的全局数据,并且每个进程只使用所需的部分。加载完整的全局数据通常更简单,但成本更高。例如,在机器学习中,训练循环可能因等待数据而被阻塞,并且每个进程都会使用额外的网络带宽。

注意

使用分布式数据加载时,重要的是每个设备(例如,每个 GPU 或 TPU)都能访问其运行计算所需的输入数据分片。这通常是使分布式数据加载更复杂且难以正确实现的原因(与上述替代方案相比)。如果错误的数据分片最终到达错误的设备,计算仍然可以无错误地运行,因为计算无法知道输入数据“应该”是什么。然而,最终结果通常会不正确,因为输入数据与预期不同。

加载 jax.Array 的通用方法#

考虑从非 JAX 产生原始数据创建单个 jax.Array 的情况。这些概念不仅适用于加载批处理数据记录,还适用于任何未直接由 JAX 计算产生的多进程 jax.Array。示例包括:1) 从检查点加载模型权重;或 2) 加载大型空间分片图像。

每个 jax.Array 都带有一个相关的 Sharding,它描述了每个全局设备所需的全局数据分片。当你从头开始创建一个 jax.Array 时,你还需要创建它的 Sharding。JAX 就是通过这种方式了解数据在设备上的布局。你可以创建任何你想要的 Sharding。实际上,你通常会根据你正在实现的并行策略类型(你将在本指南后面更详细地了解数据并行和模型并行)来选择一个 Sharding。你还可以根据原始数据将在每个进程中如何生成来选择一个 Sharding

定义了 Sharding 之后,你可以使用 addressable_devices() 来提供当前进程内加载数据所需的设备列表。(注意:“可寻址设备”是“本地设备”的更通用版本。目标是确保每个进程的数据加载器将正确的数据提供给该进程的所有本地设备。)

示例#

例如,考虑一个 (64, 128)jax.Array,你需要将其分片到 4 个进程,每个进程有 2 个设备(总共 8 个设备)。这将产生 8 个唯一的数据分片,每个设备一个。有多种方法可以对这个 jax.Array 进行分片。你可以在 jax.Array 的第二维上执行 1D 分片,使每个设备获得一个 (64, 16) 的分片,如下所示:

8 unique data shards

在上面的图中,每个数据分片都有自己的颜色,以指示哪个进程需要加载该分片。例如,你假设进程 0 的 2 个设备包含分片 AB,对应于全局数据的前 (64, 32) 部分。

你可以选择不同的分片到设备的分布方式。例如:

8 unique data shards - different distribution

这是另一个例子——一个 2D 分片

2D sharding

无论 jax.Array 如何分片,你都必须确保每个进程的数据加载器都被提供/加载了全局数据所需的(那些)分片。实现此目的有几种高级方法:1) 在每个进程中加载全局数据;2) 使用每个设备的并行数据管道;3) 使用合并的每个进程数据管道;4) 以方便的方式加载数据,然后在计算内部进行重新分片。

选项 1:在每个进程中加载全局数据#

Loading the global data in each process

使用此选项,每个进程

  1. 加载所需的完整值;并且

  2. 仅将所需的分片传输到该进程的本地设备。

这不是一种高效的分布式数据加载方法,因为每个进程都会丢弃其本地设备不需要的数据,并且总摄取数据量可能高于必要量。但是此选项可行且相对简单易实现,并且对于某些工作负载(例如,如果全局数据很小)来说,性能开销可能是可接受的。

选项 2:使用每个设备的并行数据管道#

Using a per-device data pipeline

在此选项中,每个进程为其每个本地设备设置一个数据加载器(也就是说,每个设备都有自己的数据加载器,仅用于其所需的数据分片)。

这在数据加载方面是高效的。有时,独立考虑每个设备而不是一次性考虑一个进程的所有本地设备也会更简单(参见下面的“选项 3:使用合并的每个进程数据管道”)。然而,拥有多个并发数据加载器有时可能会导致性能问题。

选项 3:使用合并的每个进程数据管道#

Using a consolidated per-process data pipeline

如果你选择此选项,每个进程

  1. 设置一个单一数据加载器,加载其所有本地设备所需的数据;然后

  2. 在传输到每个本地设备之前对本地数据进行分片。

这是执行分布式加载最有效的方法。然而,它也是最复杂的方法,因为需要逻辑来弄清楚每个设备需要哪些数据,并创建一个只加载所有这些数据(理想情况下,不加载任何额外数据)的单一数据加载。

选项 4:以方便的方式加载数据,然后在计算内部进行重新分片#

Loading  data in some convenient way, reshard inside computation

此选项更具挑战性,难以解释,但通常比上述选项(1 到 3)更易于实现。

设想一个场景,设置数据加载器来精确加载所需数据(无论是针对每个设备还是每个进程的加载器)是困难甚至不可能的。然而,仍然可能为每个进程设置一个数据加载器,加载 1 / num_processes 的数据,只是分片方式不正确。

然后,继续你之前的 2D 分片示例,假设每个进程加载一列数据更容易。

然后,你可以创建一个带有 Shardingjax.Array,该 Sharding 表示每列数据,并将其直接传递到计算中,然后使用 jax.lax.with_sharding_constraint() 立即将列分片的输入重新分片到所需的形状。由于数据在计算内部重新分片,它将通过加速器通信链路(例如,TPU ICI 或 NVLink)进行重新分片。

选项 4 与选项 3(使用合并的每个进程数据管道)具有相似的优势:

  • 每个进程仍然只有一个数据加载器;并且

  • 全局数据在所有进程中只加载一次;并且

  • 全局数据在加载方式上具有额外的好处,提供了更大的灵活性。

然而,这种方法会使用加速器互连带宽来执行重新分片,这可能会减慢某些工作负载。选项 4 还需要将输入数据表示为单独的 Sharding,除了目标 Sharding 之外。

复制#

复制描述的是多个设备拥有相同数据分片的过程。上述通用选项(选项 1 到 4)仍然适用于复制。唯一的区别是某些进程最终可能会加载相同的数据分片。本节描述完全复制和部分复制。

完全复制#

完全复制是指所有设备都拥有数据的完整副本(即,数据“分片”是整个数组值)的过程。

在下面的例子中,由于总共有 8 个设备(每个进程 2 个),你最终会得到 8 份完整数据副本。每份数据副本都是未分片的,即副本存在于单个设备上。

Full replication

部分复制#

部分复制描述的是数据有多个副本,并且每个副本都分片到多个设备上的过程。对于给定的数组值,通常有许多可能的方法来执行部分复制(注意:对于给定的数组形状,总是只有一个完全复制的 Sharding)。

以下是两个可能的例子。

在下面的第一个示例中,每个副本都分片到进程的两个本地设备上,总共有 4 个副本。这意味着每个进程都需要加载完整的全局数据,因为其本地设备将拥有数据的完整副本。

Partial replication - example 1

在下面的第二个示例中,每个副本仍然分片到两个设备上,但每对设备分布在两个不同的进程中。进程 0(粉色)和进程 1(黄色)都需要加载数据的第一行,而进程 2(绿色)和进程 3(蓝色)都需要加载数据的第二行

Partial replication - example 2

现在你已经了解了创建 jax.Array 的高级选项,接下来让我们将其应用于机器学习应用程序的数据加载。

数据并行#

纯数据并行(没有模型并行)中

  • 你在每个设备上复制模型;并且

  • 每个模型副本(即每个设备)接收不同的每个副本批次数据。

Data parallelism - example 1

将输入数据表示为单个 jax.Array 时,该 Array 包含此步骤中所有副本的数据(这称为全局批次),其中 jax.Array 的每个分片包含单个每个副本批次。你可以将其表示为跨所有设备的 1D 分片(请查看下面的示例)——换句话说,全局批次由所有每个副本批次沿批次轴连接在一起组成。

Data parallelism - example 2

应用此框架,你可能会得出结论,进程 0 应该获得全局批次的第一个四分之一(8 个中的 2 个),而进程 1 应该获得第二个四分之一,依此类推。

但是你怎么知道第一个四分之一是什么呢?你怎么确保进程 0 得到第一个四分之一呢?幸运的是,数据并行有一个非常重要的技巧,这意味着你无需回答这些问题,并且使整个设置变得更简单。

关于数据并行的一个重要技巧#

技巧在于你不需要关心哪个副本获得哪个分片批次。因此,哪个进程加载批次并不重要。原因是,由于每个设备都对应一个执行相同操作的模型副本,因此全局批次中哪个设备获得哪个分片批次并不重要。

这意味着你可以自由地重新排列全局批次内的每个副本批次。换句话说,你可以自由地随机化每个设备获得哪个数据分片。

例如

Data parallelism - example 3

通常,如上所示重新排列 jax.Array 的数据分片不是一个好主意——你实际上是在置换 jax.Array 的值!然而,对于数据并行,全局批次顺序没有意义,你可以自由地重新排列全局批次中的每个副本批次,如前所述。

这简化了数据加载,因为它意味着每个设备只需要一个独立的每副本批次流,这在大多数数据加载器中可以通过为每个进程创建一个独立的管道并将生成的每进程批次分块为每副本批次来轻松实现。

Data parallelism - example 4

这是选项 2:合并的每个进程数据管道的一个实例。你也可以使用其他选项(例如本文档前面介绍的 0、1 和 3),但此选项相对简单且高效。

以下是使用 tf.data 实现此设置的示例:

import jax
import tensorflow as tf
import numpy as np

################################################################################
# Step 1: setup the Dataset for pure data parallelism (do once)
################################################################################
# Fake example data (replace with your Dataset)
ds = tf.data.Dataset.from_tensor_slices(
    [np.ones((16, 3)) * i for i in range(100)])

ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())

################################################################################
# Step 2: create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step). This can be used with batches
# produced by different data loaders as well!
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()

mesh = jax.make_mesh((jax.device_count(),), ('batch',))
sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))
global_batch_array = jax.make_array_from_process_local_data(
    sharding, per_process_batch)

数据 + 模型并行#

模型并行中,你将每个模型副本分片到多个设备上。如果你使用纯模型并行(没有数据并行)

  • 只有一个模型副本分片到所有设备上;并且

  • 数据(通常)在所有设备上完全复制。

本指南考虑你同时使用数据并行和模型并行的情况。

  • 你将多个模型副本中的每个副本分片到多个设备上;并且

  • 你将数据部分复制到每个模型副本上——同一模型副本中的每个设备获得相同的每个副本批次,而跨模型副本的设备获得不同的每个副本批次。

进程内的模型并行#

为了数据加载的目的,最简单的方法可能是将每个模型副本分片到单个进程的本地设备内。

在此示例中,我们切换到 2 个进程,每个进程有 4 个设备(而不是 4 个进程,每个进程 2 个设备)。考虑一个场景,其中每个模型副本都分片到单个进程的 2 个本地设备上。这将导致每个进程有 2 个模型副本,总共有 4 个模型副本,如下所示:

Data and model parallelism - example 1

在这里,输入数据再次表示为单个 jax.Array,带有一个 1D 分片,其中每个分片都是一个每副本批次,但有一个例外:

  • 与纯数据并行情况不同,你引入了部分复制,并创建了 1D 分片全局批次的 2 份副本。

  • 这是因为每个模型副本由需要每个副本批次副本的 2 个设备组成。

Data and model parallelism - example 2

将每个模型副本保持在单个进程内可以使事情更简单,因为你可以重用上面描述的纯数据并行设置,只是你还需要复制每个副本批次。

Data and model parallelism - example 3

注意

将每个副本批次复制到正确的设备上也非常重要!虽然关于数据并行性的那个非常重要的技巧意味着你不在乎哪个批次最终落在哪个副本上,但你确实关心单个副本只获得一个批次

例如,这是可以的:

Data and model parallelism - example 4

然而,如果你不小心将每个批次加载到错误的本地设备上,你可能会意外地创建未复制的数据,尽管 Sharding(和并行策略)表明数据是复制的

Data and model parallelism - example 4

如果你在一个进程内意外地创建了一个 jax.Array,其数据本应被复制但实际未复制,JAX 会引发错误(尽管对于跨进程的模型并行性并非总是如此;请参阅下一节)。

以下是使用 tf.data 实现进程内模型并行和数据并行的示例:

import jax
import tensorflow as tf
import numpy as np

################################################################################
# Step 1: Set up the Dataset with a different data shard per-process (do once)
#         (same as for pure data parallelism)
################################################################################
# Fake example data (replace with your Dataset)
per_process_batches = [np.ones((16, 3)) * i for i in range(100)]
ds = tf.data.Dataset.from_tensor_slices(per_process_batches)

ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index())

################################################################################
# Step 2: Create a jax.Array of per-replica batches from the per-process batch
# produced from the Dataset (repeat every step)
################################################################################
# Grab just the first batch from the Dataset for this example
per_process_batch = ds.as_numpy_iterator().next()

num_model_replicas_per_process = 2 # set according to your parallelism strategy
num_model_replicas_total = num_model_replicas_per_process * jax.process_count()

# Create an example `Mesh` for per-process data parallelism. Make sure all devices
# are grouped by process, and then resize so each row is a model replica.
mesh_devices = np.array([jax.local_devices(process_idx)
                         for process_idx in range(jax.process_count())])
mesh_devices = mesh_devices.reshape(num_model_replicas_total, -1)
# Double check that each replica's devices are on a single process.
for replica_devices in mesh_devices:
  num_processes = len(set(d.process_index for d in replica_devices))
  assert num_processes == 1
mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"])

# Shard the data across model replicas. You don't shard across the
# data_parallelism mesh axis, meaning each per-replica shard will be replicated
# across that axis.
sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("model_replicas"))

global_batch_array = jax.make_array_from_process_local_data(
    sharding, per_process_batch)

跨进程的模型并行#

当模型副本分布在不同进程时,情况会变得更有趣,无论是:

  • 因为单个副本无法适应一个进程;或者

  • 因为设备分配不是那样设置的。

例如,回到之前 4 个进程、每个进程 2 个设备的设置,如果你这样分配设备到副本:

Model parallelism across processes - example 1

这与之前的进程内模型并行示例采用相同的并行策略——4 个模型副本,每个副本分片到 2 个设备上。唯一的区别是设备分配——每个副本的两个设备被分到不同的进程中,每个进程只负责每个副本批次的单个副本(但负责两个副本)。

像这样将模型副本分散到不同进程中可能看起来随意且不必要(在此示例中确实如此),但实际部署可能会采用这种设备分配方式,以便最好地利用设备之间的通信链路。

数据加载现在变得更加复杂,因为需要在进程之间进行额外的协调。在纯数据并行和进程内模型并行的情况下,重要的是每个进程只加载一个唯一的数据流。现在,某些进程必须加载相同的数据,而某些进程必须加载不同的数据。在上面的示例中,进程 02(分别为粉色和绿色)必须加载相同的 2 个每副本批次,而进程 13(分别为黄色和蓝色)也必须加载相同的 2 个每副本批次(但与进程 02 的批次不同)。

此外,重要的是每个进程不要混淆其 2 个每个副本的批次。虽然你不在乎哪个批次落在哪个副本上(数据并行性中非常重要的技巧),但你确实需要确保副本中的所有设备都获得相同的批次。例如,这将是糟糕的:

Model parallelism across processes - example 2

注意

截至 2023 年 8 月,JAX 无法检测 jax.Array 在进程间的分片是否应该被复制但实际未复制,并且在运行计算时会产生错误结果。所以务必小心,不要这样做!

要在每个设备上获得正确的每个副本批次,你需要将全局输入数据表示为以下 jax.Array

Model parallelism across processes - example 3