训练指南#

传统上,机器学习代码库依赖各种库来处理训练大型复杂模型时所需的繁琐记录和参数管理。虽然方便,但这些库可能会抽象掉 JAX 提供的关键功能和核心 API。因此,本指南旨在演示直接在 JAX 中编写简单且高性能的机器学习训练代码的最佳实践(即“配方”)。遵循下文记录的模式,将使您的机器学习工作负载能够最大限度地利用我们的编译器 (XLA) 来获得性能和可扩展性。大多数训练脚本大致遵循以下结构:

def train_loop(config: Config):
  record_writer = RecordWriter()
  train_state = init_train_state(config)
  train_state = jax.tree.map(jax.ref.new_ref, train_state)
  batch = iter(get_dataset_on_device(config))
  for step in range(config.num_train_steps):
    metrics = train_step(config, train_state, next(batch))
    record_writer({"step": step} | metrics)

对于上述每一行代码,我们将解释最佳实践,并展示我们汇集的各种核心技术,使您能够在 JAX 中编写既简单又极具性能的代码。上述代码是一个自包含、功能完整的配套脚本的一部分。在该脚本中,我们初始化了一个 Vaswani 等人 (2017) 提出的 Transformer 解码器,定义了用于下一词预测的训练损失函数,并使用纯 JAX 实现了 Adam 优化器。其中的代码适用于 TPU、CPU 和 GPU,以及单机和多机系统。因此,我们将“设备”或“加速器”互换使用,指代 JAX 主要进行算术运算的硬件(无论是 TPU、GPU 还是 CPU),并将“宿主系统”指代仅使用宿主 CPU 执行的操作。在本指南中,为了简洁起见,我们将略过 JAX API 的许多方面。您可以随时查阅我们的 API 文档了解详情。然而,为了理解后续内容,必须深入探讨一个核心的 JAX 概念。

设备网格与分片#

JAX 采用单程序多数据 (SPMD) 并行模型。这意味着我们编写一个在多个设备上运行的单一程序,并使用注解来指定每个设备负责数据的哪一部分。实现这一点的两个主要概念是 jax.sharding.Meshjax.P

设备网格#

一个 jax.sharding.Mesh 是将我们所有的加速器排列成一个 NumPy ndarray,并为设备数组的轴添加字符串标签。使用数组的原因在于,它提供了一种非常方便的注解方式,用于描述数组应如何在设备间进行分区。在本介绍中,我们将使用有序字典的符号 [1],例如 {"x": 2, "y": 4} 指代形状为 (2, 4) 且轴标签为 "x""y" 的设备网格。为了对数组 param 进行分片,我们用 jax.P 对其进行修饰,这是一个包含 str | None 元素的元组,其长度与数组的维度相同。jax.P 指定了数组的哪些轴要在设备的哪些轴上进行分片。有关分片符号和分片计算的更详细说明,请参阅 分布式数组和自动并行化。一些常见的分片策略(如数据并行、完全分片数据并行和基本张量并行)将在 实现高性能 中介绍。

示例

假设我们有一个 {"x": 2, "y": 4} 的设备网格和一个形状为 (32, 64, 64, 128) 的数组 param。如果我们用 jax.P(None, “x”, “y”, None) 对此数组进行分片,最终会得到分布在各个设备上大小为 (32, 32, 16, 128) 的分片。None 表示该轴不应分片。JAX 会隐式广播尾随轴,因此使用 jax.P(None, “x”, “y”) 可以更简洁地实现相同的分片。因此,完全复制数组(任意维度)的简写为 jax.P()

示例

当更高级的网格几何结构与设备的通信层级对齐时,使用起来会很方便。主机间通信通常比加速器间通信慢。假设我们有两台主机,每台主机连接八个 GPU。我们可以将设备排列成 {"host": 2, "gpu": 8} 的网格。然后我们可以按如下方式分片参数:

param = jnp.zeros((256, 192), out_sharding=jax.P("gpu", None))

param 的整体将被复制两次,但在每个主机内部,它将分布在八个本地连接的 GPU 上,每个 GPU 在 HBM 中存储一个形状为 (32, 192) 的分片。这对于 完全分片数据并行 (FSDP) 特别有用。

训练状态初始化#

@jax.jit
def init_train_state(config: Config) -> dot_dict:
  train_state = dot_dict()
  train_state.params = init_param_state(config)
  train_state.opt = jax.tree.map(init_adam_state, train_state.params)
  return train_state

在开始之前,我们需要做的第一件事是设置训练状态。训练状态(不出所料)封装了训练过程中所有有状态的方面。这通常至少包括模型参数和优化器状态。我们构建此函数的方式(尽管您可以选择其他方式)是:

  1. 创建一个嵌套字典序列来存放模型参数,然后

  2. 对这些参数使用 jax.tree.map(),生成一组类似的嵌套字典来存放相应的优化器状态。(下文有更多说明。)

参数初始化#

@jax.jit
def init_train_state(config: Config) -> dot_dict:
  train_state = dot_dict()
  train_state.params = init_param_state(config)
  train_state.opt = jax.tree.map(init_adam_state, train_state.params)
  return train_state

为了初始化参数,我们构建了一系列与神经网络语义部分相对应的嵌套字典。如果我们使用像 PyTorch 或 Flax 这样的基于层的库,它们可能对应于神经网络层。对于这个例子,我们实际上可以用一个完全扁平的字典来处理,但嵌套的方法不仅方便与 JAX 中的某些 API 配合使用,也有助于构建我们的代码。

def init_param_state(config: Config) -> dot_dict:
  root_key = jax.random.key(config.param_seed)
  key = map(ft.partial(jax.random.fold_in, root_key), it.count())
  zero_init = jax.nn.initializers.constant(0.0)
  he_init = jax.nn.initializers.he_normal(1, 1)
  dtype = config.dtype

  params = dot_dict(
    pos_embed=zero_init(next(key), (config.seq_length, config.embed_dim), dtype, config.pos_embed),
    layers=dot_dict(),
  )
  params.embedding = he_init(next(key), (config.vocab_size, config.embed_dim), dtype, config.embed)
  params.linear_in = dot_dict(
    kernel=he_init(next(key), (1, config.embed_dim), dtype, config.in_kernel),
    bias=zero_init(next(key), (config.embed_dim,), dtype, config.in_bias),
  )
  params.linear_out = dot_dict(
    kernel=he_init(next(key), (config.embed_dim, config.vocab_size), dtype, config.out_kernel),
  )
  for layer in range(config.num_layers):
    qkv_shape = (3, config.embed_dim, config.num_heads, config.head_dim)
    out_shape = (config.num_heads, config.head_dim, config.embed_dim)
    params.layers[layer] = dot_dict(
      attention=dot_dict(
        qkv=he_init(next(key), qkv_shape, dtype, config.att_qkv),
        out=he_init(next(key), out_shape, dtype, config.att_out),
      ),
      mlp=dot_dict(
        in_kernel=he_init(next(key), (config.embed_dim, config.mlp_dim), dtype, config.mlp_in),
        out_kernel=he_init(next(key), (config.mlp_dim, config.embed_dim), dtype, config.mlp_out),
      ),
    )
  return params

我们的 get_param_state 函数利用了 jax.nn.initializers 中提供的 constanthe_normal 工厂。这些工厂返回一个初始化器,这是一个符合以下协议的函数:

class Initializer(Protocol):
    def __call__(self, key, shape, dtype, out_sharding) -> jax.Array:
        ...

JAX 的函数式风格要求显式处理所有随机性(参见 伪随机数),因此我们设置了一个小的迭代器来生成 PRNG 密钥。然后,为了构建我们的参数,我们在 params 嵌套字典中各自的位置初始化它们,提供来自 Config 类的参数形状、数据类型和分片信息。

注意

通过在此处指定分片,我们直接在设备网格中需要它们的正确设备上初始化每个参数的分片,从而避免了不必要的主机到设备传输;或者,对于无法放入系统内存的模型,避免了内存不足错误。

优化器初始化#

@jax.jit
def init_train_state(config: Config) -> dot_dict:
  train_state = dot_dict()
  train_state.params = init_param_state(config)
  train_state.opt = jax.tree.map(init_adam_state, train_state.params)
  return train_state

在设置优化器状态时,情况比构建模型参数要复杂一些。Adam 优化器要求我们为每个参数跟踪三个优化状态:munucount。其中最简单的是 count,它存储了我们执行的训练步数。这只是一个用于对 Adam 更新进行去偏置的标量。munu 状态将是与相应参数 param 具有相同形状、数据类型和分片的数组 [2]

def init_adam_state(param: jax.Array) -> dot_dict:
  adam_state = dot_dict(mu=jnp.zeros_like(param), nu=jnp.zeros_like(param), count=jnp.array(0))
  return adam_state

当我们使用 jax.tree.map() 时,它会遍历 train_state.params 中的项。对于每个参数,它都会创建一个对应的 Adam 状态,从而生成一个新的嵌套字典,映射 train_state.params 的结构。这个新结构中的每个叶子节点都包含相应参数的优化器状态。

训练步(函数式转换)#

@jax.jit
def train_step(config: Config, train_state: dot_dict, batch: dict) -> dict:
  def loss_fn(params):
    logits = model_apply(config, params, batch["observed_ids"])
    labels = jax.nn.one_hot(batch["target_ids"], config.vocab_size)
    return -(labels * jax.nn.log_softmax(logits)).mean()

  params = jax.tree.map(jax.ref.get, train_state.params)
  loss, grad = jax.value_and_grad(loss_fn)(params)
  jax.tree.map(ft.partial(adam_update, config), train_state.params, grad, train_state.opt)
  metrics = {"train_loss": loss}
  return metrics

训练步是计算模型相对于当前参数的梯度,并利用梯度和优化器更新参数的地方。要在 JAX 中执行此操作,我们定义模型的前向传播,然后利用 JAX 的函数式转换自动生成反向传播,用于计算梯度并执行更新。

模型前向传播#

def model_apply(config: Config, params: dot_dict, tokens: jax.Array) -> jax.Array:
  out = params.embedding.at[tokens].get(out_sharding=config.act_seq)
  out += params.pos_embed
  del tokens

  for layer in range(config.num_layers):
    block = params.layers[layer]
    att_skip = out  # 1 billion dollars in venture capital funding please
    qkv = jnp.einsum("bsd,3dkh->bs3kh", out, block.attention.qkv, out_sharding=config.act_att)
    out = jax.nn.dot_product_attention(qkv[:, :, 0, :], qkv[:, :, 1, :], qkv[:, :, 2, :], is_causal=True)
    out = jnp.einsum("bskh,khd->bsd", out, block.attention.out, out_sharding=config.act_seq)
    out += att_skip
    out *= jax.lax.rsqrt(jnp.linalg.norm(out, axis=-1, keepdims=True) + 1e-6)

    mlp_skip = out  # machine learning circa 1986
    out = jnp.einsum("bsd,dh->bsh", out, block.mlp.in_kernel, out_sharding=config.act_hidden)
    out = jax.nn.gelu(out)
    out = jnp.einsum("bsh,hd->bsd", out, block.mlp.out_kernel, out_sharding=config.act_seq)
    out += mlp_skip
    out *= jax.lax.rsqrt(jnp.linalg.norm(out, axis=-1, keepdims=True) + 1e-6)

  logits = jnp.einsum("bsd,dl->bsl", out, params.linear_out.kernel, out_sharding=config.act_seq)
  return logits

除了我们提供的 out_sharding 注解外,模型的前向传播过程大多很平常。这些注解声明了操作执行后应达到的结果分片。编译器利用这些激活分片,连同我们在初始化模型时提供的参数分片,动态插入集合通信原语,在设备之间传输参数和激活。通过选择良好的分片策略,我们可以获得高性能的训练(和推理)代码。我们将在题为 实现高性能 的章节中介绍一些适用于大多数用例的标准策略。有关支撑分片策略设计的原则的详细讨论,请参阅 扩展指南 (The Scaling Cookbook)

梯度与优化器更新#

@jax.jit
def train_step(config: Config, train_state: dot_dict, batch: dict) -> dict:
  def loss_fn(params):
    logits = model_apply(config, params, batch["observed_ids"])
    labels = jax.nn.one_hot(batch["target_ids"], config.vocab_size)
    return -(labels * jax.nn.log_softmax(logits)).mean()

  params = jax.tree.map(jax.ref.get, train_state.params)
  loss, grad = jax.value_and_grad(loss_fn)(params)
  jax.tree.map(ft.partial(adam_update, config), train_state.params, grad, train_state.opt)
  metrics = {"train_loss": loss}
  return metrics

为了计算梯度,我们定义训练损失。这是一个以参数为自变量的函数,返回一个标量,该标量汇总了我们的模型(使用当前 train_state 参数)对数据拟合的程度。

loss, grad = jax.value_and_grad(loss_fn)(params)

通过将此函数提供给 jax.value_and_grad(),我们将其转换为一个返回标量值和 loss_fnparams 处评估梯度的函数(即梯度)。由于我们已根据一系列嵌套字典定义了参数,梯度也将是一系列嵌套字典,映射参数结构。请记住,与参数不同,优化器状态包含一些额外的、更深层的嵌套字典,对应于每个参数的优化器状态。在阅读解释之前,请花点时间思考一下以下函数调用的语义可能是什么:

jax.tree.map(ft.partial(adam_update, config), train_state.params, grad, train_state.opt)

检查函数 adam_apply 的调用签名会给我们一些提示:

def adam_update(config: Config, param: jax.Ref, grad: jax.Array, adam_state: dot_dict):
  adam_state.mu[...] = (1 - config.beta_1) * adam_state.mu[...] + config.beta_1 * grad
  adam_state.nu[...] = (1 - config.beta_2) * adam_state.nu[...] + config.beta_2 * grad**2
  adam_state.count[...] += 1

  mu_hat = adam_state.mu[...] / (1 - config.beta_1 ** adam_state.count[...])
  nu_hat = adam_state.nu[...] / (1 - config.beta_2 ** adam_state.count[...])
  param[...] -= config.learning_rate * mu_hat / (jnp.sqrt(nu_hat + config.eps_root) + config.eps)

由于 train_state.params 是第一个参数,jax.tree.map() 使用其树结构来指导映射过程 [3]。这意味着 train_state.opt 仅被遍历到与 train_state.params 叶子相同的深度。因此,每个参数的优化器状态被作为一个完整的子树传入,这使我们能够在 adam_apply 中轻松访问给定 param 的所有相关状态(如 munu)。

提示

如果我们希望在模型中的不同参数上使用不同的优化算法和状态(或冻结某些参数),可以通过修改 adam_apply 的主体并将 jax.tree.map() 替换为 jax.tree_util.tree_map_with_path() 来实现,后者允许操作函数根据参数自定义其行为。

训练循环#

def train_loop(config: Config):
  record_writer = RecordWriter()
  train_state = init_train_state(config)
  train_state = jax.tree.map(jax.ref.new_ref, train_state)
  batch = iter(get_dataset_on_device(config))
  for step in range(config.num_train_steps):
    metrics = train_step(config, train_state, next(batch))
    record_writer({"step": step} | metrics)

在训练期间,我们必须协调数据在两个关键角色之间的流动:宿主系统和加速器。确保这些系统之间的顺畅交互是编写高性能训练代码的关键。Python 的 GIL 通常会在这里造成重大阻碍,但为了解决这个问题,JAX 采用的异步分发范式使这种协调变得易于实现。但是,为了利用这种范式,我们在构建训练步时需要留意代码的执行方式。

通过异步分发提升效率#

宿主系统执行的最重要任务之一是获取数据并将其放置在加速器上,以便加速器永远不需要等待数据。加速器在训练步之间处于空闲等待的时间被称为步进气泡 (step bubble)。我们可以利用异步分发来最小化步进气泡。让我们看看它在我们的训练循环中是如何工作的,暂时忽略有关 record_writer 的那一行。

for step in range(config.num_train_steps):
  metrics = train_step(config, train_state, next(batch))

当这段代码执行时,Python 首先会查询范围迭代器,获取 step(值为 0),然后调用 next(batch),这需要一些时间来检索批次数据。接着,调用 train_step。到目前为止,一切都很正常。

接下来发生的事情很有趣。由于 jax.jit() 修饰的调用是非阻塞的,train_step 的调用会立即返回给 Python 解释器。当计算在加速器上排队时,实际上还没有执行任何工作。Python 循环继续进行,推进步数计数器,并为下一次迭代调用 next(batch)。一旦第二次调用 train_step,其输入现在是上一次 JIT 调用中 train_state 的变体引用和一批新的数据。运行时非常智能,它看到为了执行第二次 train_step 调用,我们首先需要实现步骤 0train_state 结果以进行修改。因此,它触发了第一步的计算,关键的是,当这种情况发生时,train_step 再次立即返回,循环继续跳过。Python 现在会一直向前运行,直到在步骤 3 遇到 next(batch) 函数,该函数继续在 Python 中执行并加载数据,同时第一步训练步(这次是真正的)正在执行。就这样,我们可以在没有任何传统多处理的情况下,同时加载数据并在加速器上进行数学运算。 [4]

        ---
displayMode: compact
---
gantt
    title Synchronous Dispatch: No Overlap
    axisFormat %

    section Host
    next(batch) :gb0, 0, 1000s
    next(batch) :gb1, after ajc0, 1000s
    next(batch) :gb2, after ajc1, 1000s

    section Accelerator

    train_step 0 :ajc0, after gb0, 2000s
    train_step 1 :ajc1, after gb1, 2000s
    
        ---
displayMode: compact
---
gantt
    title JAX Asynchronous Dispatch: Host-Device Overlap
    axisFormat %

    section Host
    %% Task: id, name, start, duration_or_end
    next(batch) :gb0, 0, 1000s
    next(batch) :gb1, after gb0, 1000s
    next(batch) :gb2, after gb1, 1000s
    next(batch) :gb3, after jc0, 1000s
    next(batch) :gb4, after jc1, 1000s

    section Accelerator
    %% Task: id, name, start, duration_or_end
    train_step 0 :jc0, after gb1, 2000s
    train_step 1 :jc1, after jc0, 2000s
    train_step 2 :jc2, after jc1, 2000s
    

常见错误#

在 Python 中编写异步分发代码时,需要警惕两个主要错误,以免干扰我们精心设计的计算协调。

请求设备到主机的传输#

到目前为止,我们忽略了变量 metrics 会发生什么。事实上,如果保持挂起状态,什么都不会发生,我们将如宣传的那样实现良好的重叠。然而,通常情况下,我们希望观察训练步的遥测数据,例如当前损失、梯度统计等。假设我们要插入如下代码:

metrics = train_step(config, train_state, next(batch))
print({"step": step} | metrics)

循环不会继续,print 会导致将 metrics 中的任何设备端数组传输到主机。这会中断 Python 解释器,代码被迫同步执行,从而产生步进气泡。解决方案稍微有点反直觉:在每一步,我们收集上一步的遥测数据。

class RecordWriter:
  prev_metrics = None

  def __call__(self, cur_metrics: dict):
    self.prev_metrics, log_metrics = cur_metrics, self.prev_metrics
    if log_metrics is None:
      return
    print(*it.starmap("{}: {}".format, log_metrics.items()), sep="\t")

metrics = train_step(config, train_state, next(batch))

像这样的小辅助函数对于实现良好的重叠并充分利用宿主系统和加速器的资源至关重要。当然,这里的简单 print 语句可以替换为任何从加速器请求数据的 Python 操作。

中断加速器#

另一种浪费大量云端计算资金的常见方式是,无意中在训练步之外向加速器排队数学运算。假设我们正在使用余弦学习率调度。

def learning_rate(count, init_value: float = 1e-4, decay_steps: int = 10_000, alpha: float = 1e-6):
    cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * jnp.minimum(count, decay_steps) / decay_steps))
    return init_value * (1 - alpha) * cosine_decay

一种常见的模式是希望在收集的其他指标旁边可视化该调度。然而,即使我们使用了之前定义的巧妙的 record_writer 类,以下代码也会在加速器上创建一个气泡。

metrics = train_step(config, train_state, next(batch))
record_writer({"step": step, "learning_rate": learning_rate(step)} | metrics)

这是因为我们在计算中使用了 jax.numpy。当调用 jax.numpy.minimum() 时,Python 整数 step 会被提升为 jax.Array 并传输到加速器(主机到设备传输)。计算现在在我们的主要 train_step 之外的加速器上排队。为了 print 结果,该值必须传输回主机(设备到主机传输)。这种往返强迫加速器与主机同步,我们因为制造了性能气泡而浪费了金钱。避免此问题的两种方法是:对这些计算使用 NumPy,或者使用 jax.default_device() 上下文管理器。

metrics = train_step(config, train_state, next(batch))
with jax.default_device('cpu'):
  record_writer({"step": step, "learning_rate": learning_rate(step)} | metrics)

数据加载#

除了重叠数据的实际加载(即将数据从网络存储检索到主机)之外,JAX 还允许我们将数据本身的主机到设备传输与训练步的计算进行重叠。特殊函数 jax.device_put() 被精心设计为非阻塞的,异步执行,这使得在我们的训练步环境中使用它是完全没问题的。但是,有一个专门为此类数据加载任务设计的更方便的函数。在以下代码中,dataset 是一个普通的 Python 迭代器,产生一个批处理数据的 dict。通过使用 jax.make_array_from_process_local_data() 映射此迭代器,我们生成了一个新的迭代器。从这个新迭代器中生成数据将产生放置在设备上的数据,准备供我们的训练步使用。在内部,它将使用 jax.tree.map() 来创建 jax.Array 对象,并将它们排队传输到设备。只要数据批处理速度足够快,在 TPU 和 GPU 上,这些传输都将与训练步计算重叠。

def get_dataset_on_device(config: Config) -> Iterator[dict[str, jax.Array]]:
  datset = get_dataset(config)
  sharding = jax.P(config.mesh_axis_names)
  return map(ft.partial(jax.make_array_from_process_local_data, sharding), datset)

实现高性能#

在本节中,我们将描述对训练有用的三种主要模型并行形式。在训练期间,吞吐量至关重要;也就是说,我们希望最大化每秒的平均操作次数。这与推理不同,推理的目标是通过确保所有操作在尽可能短的时间内完成来最小化延迟。以吞吐量作为训练的终极目标,本节介绍了三种主要的训练期间分片策略。对于每种策略,我们概述了实现它的 JAX 分片,并描述了涉及的通信原语,以便在研究程序跟踪时,您有地标可以寻找,以确认程序是否按预期运行。我们在下述代码块中定义的分片变量对应于它们在初始化模型前向传播中的使用。但在配套脚本中,这些以及训练代码的其他方面都是使用全局 Config 类方便设置的。

@jax.tree_util.register_static
@dataclass(kw_only=True, frozen=True)
class Config:
  mesh_axis_names: tuple[str, ...] = ("fsdp",)
  mesh_shape: tuple[int, ...] = (8,)
  seq_length: int = 128

  num_train_steps: int = 10**6
  host_batch_size: int = 16
  learning_rate: float = 1e-4
  beta_1: float = 0.9
  beta_2: float = 0.999
  eps: float = 1e-8
  eps_root: float = 0.0

  param_seed: int = 12738
  num_layers: int = 4
  embed_dim: int = 512
  mlp_dim: int = 512 * 4
  vocab_size: int = 2**8  # uint8 ascii encoding
  num_heads: int = 8
  head_dim: int = 128
  dtype: str = "bfloat16"

  embed: jax.P = jax.P(None, None)
  pos_embed: jax.P = jax.P(None, None)
  att_qkv: jax.P = jax.P(None, "fsdp", None, None)
  att_out: jax.P = jax.P("fsdp", None, None)
  mlp_in: jax.P = jax.P("fsdp", None)
  mlp_out: jax.P = jax.P(None, "fsdp")
  in_kernel: jax.P = jax.P(None, None)
  in_bias: jax.P = jax.P(None)
  out_kernel: jax.P = jax.P("fsdp", None)
  out_bias: jax.P = jax.P(None)

  act_ids: jax.P = jax.P("fsdp")
  act_seq: jax.P = jax.P("fsdp", None, None)
  act_att: jax.P = jax.P("fsdp", None, None, None)
  act_hidden: jax.P = jax.P("fsdp", None, None)

  def __post_init__(self):
    mesh = jax.make_mesh(self.mesh_shape, self.mesh_axis_names, len(self.mesh_shape) * (AxisType.Explicit,))
    jax.sharding.set_mesh(mesh)

数据并行#

数据并行是最常见且易于理解的并行形式。在这种方案中,每个加速器存储模型参数的完整副本,我们沿批次轴对激活进行分片,以拆分梯度计算。为了计算梯度,每个加速器执行单独的前向和反向传播。然后,在参数更新之前,XLA 插入一个 AllReduce 来共享更新并保持模型同步。

网格

mesh = jax.sharding.Mesh(jax.devices(), ('devices',))

参数分片

pos_embed = jax.P(None, None)
att_qkv = jax.P(None, None, None, None)
att_out = jax.P(None, None, None)
mlp_in = jax.P(None, None)
mlp_out = jax.P(None, None)
in_kernel = jax.P(None, None)
in_bias = jax.P(None)
out_kernel = jax.P(None, None)
out_bias = jax.P(None)

激活分片

act_ids = jax.P("devices")
act_seq = jax.P("devices", None, None)
act_att = jax.P("devices", None, None, None)
act_hidden = jax.P("devices", None, None)

完全分片数据并行 (FSDP)#

数据并行分片的缺点是必须在 HBM 中保留模型参数的多个完整冗余副本。对于小型模型,这是一种非常高性能的策略,但由于 HBM 供不应求,我们也需要对模型参数进行分片。在完全分片数据并行 (FSDP) 策略中,我们同时对激活和模型参数进行分片。现在,当发生前向传播时,参数会被一一反分片(通过 AllGather)还原为完整数组,然后再应用于激活。然而,这种反分片是短暂且临时的,从而在 HBM 上实现了巨大的节省。在反向传播中,每个 AllGather 变为一个 ReduceScatter。然后在优化器更新处有一个最终的 ReduceScatter 来同步梯度。与数据并行相比,总通信量增加了 50%,但我们的 HBM 压力减少了(模型大小除以设备数量)。

网格

mesh = jax.make_mesh((128*4,), ("fsdp",))

参数分片

pos_embed = jax.P(None, None)
att_qkv = jax.P(None, "fsdp", None, None)
att_out = jax.P("fsdp", None, None)
mlp_in = jax.P("fsdp", None)
mlp_out = jax.P(None, "fsdp")
in_kernel = jax.P(None, None)
in_bias = jax.P(None)
out_kernel = jax.P("fsdp", None)
out_bias = jax.P(None)

激活分片

act_ids = jax.P("fsdp")
act_seq = jax.P("fsdp", None, None)
act_att = jax.P("fsdp", None, None, None)
act_hidden = jax.P("fsdp", None, None)

注意

虽然 FSDP 涉及比数据并行多得多的通信,但在实践中,我们能够将通信与计算重叠,从而隐藏通信,在极大改善 HBM 预算的情况下实现相同的吞吐量。

张量并行#

如果我们的模型足够大且结构合理,在我们的加速器上划分单个示例内的计算是有益的。以矩阵乘法为例,我们可以将大型矩阵乘法分散到两个或四个加速器上。这涉及显著更多的通信,因此该策略仅适用于具有极高算术强度的计算,例如超大型矩阵乘法。对于多头自注意力,我们选择沿头轴分片并复制序列轴,因为这提供了最自然的并行度。如果 MLP 足够大,我们也可以有效地对矩阵乘法进行分片。

网格

mesh = jax.make_mesh((128,4), ("fsdp", "tensor"))

参数分片

pos_embed = jax.P(None, "tensor")
att_qkv = jax.P(None, "fsdp", "tensor", None)
att_out = jax.P("fsdp", None, None)
mlp_in = jax.P("fsdp", "tensor")
mlp_out = jax.P("tensor", "fsdp")
in_kernel = jax.P(None, None)
in_bias = jax.P(None)
out_kernel = jax.P("fsdp", None)
out_bias = jax.P(None)

激活分片

act_ids = jax.P("fsdp")
act_seq = jax.P("fsdp", None, None)
act_att = jax.P("fsdp", None, "tensor", None)
act_hidden = jax.P("fsdp", None, "tensor")