JAX 训练食谱#

传统上,机器学习代码库依赖于各种库来执行训练大型、复杂模型所需的许多记账和参数处理工作。虽然方便,但这些库可能会抽象化 JAX 提供的核心功能和 API。因此,本食谱的目的是直接在 JAX 中演示编写简单但高性能的机器学习训练代码的最佳实践(或“食谱”)。遵循下面记录的模式将使您的机器学习工作负载能够最大程度地利用我们的编译器 (XLA) 来提高性能和可追溯性。大多数训练脚本大致遵循以下结构:

def train_loop(config: Config):
  record_writer = RecordWriter()
  train_state = init_train_state(config)
  train_state = jax.tree.map(jax.ref.new_ref, train_state)
  batch = iter(get_dataset_on_device(config))
  for step in range(config.num_train_steps):
    metrics = train_step(config, train_state, next(batch))
    record_writer({"step": step} | metrics)

对于上面的每一行代码,我们将解释最佳实践,并展示我们为赋能您在 JAX 中编写简单、但性能极佳的代码而集成的核心技术。上面的代码是一个完全独立、功能完整的配套脚本的片段,在该脚本中我们初始化了一个Vaswani 等人 (2017) 的 Transformer 解码器,定义了下一个 token 预测的训练损失,以及Adam 优化器,全部使用纯 JAX 实现。其中的代码适用于 TPU、CPU 和 GPU,以及单主机和多主机系统。因此,我们将术语设备加速器用于相互替换地指代 JAX 主要执行算学的硬件——无论是 TPU、GPU 还是 CPU——并将主机系统用于仅使用主机 CPU 执行的操作。在本指南中,为了方便起见,我们将忽略许多 JAX API 的细节。这些 API 您可以根据自己的喜好在我们的 API 文档中进行查阅。然而,对于后续许多内容的一致性,有一个核心的 JAX 概念是必须详细探讨的。

设备网格与分片#

JAX 采用单程序多数据 (SPMD) 并行模型。这意味着我们编写一个在多个设备上运行的单个程序,并使用注解来指定每个设备负责数据的哪一部分。这其中的两个主要概念是jax.sharding.Meshjax.P

设备网格#

一个jax.sharding.Mesh 是我们将所有加速器排列成一个 NumPy ndarray,并为设备数组的轴(axes)提供字符串标签。使用数组的原因是它可以非常方便地注解数组如何在设备之间进行分区。在本介绍中,我们将使用有序字典[1] 的表示法,因此 {"x": 2, "y": 4} 表示一个形状为 (2, 4)、带有已标记轴 "x""y" 的设备网格。要对一个数组 param 进行分片,我们使用 jax.P 来注解它,jax.P 是一个元组,包含与数组维度数量相同的 str | None 元素。jax.P 指定了我们数组的哪些轴要被分片到设备的哪些轴上。关于分片和分片计算的表示法的更详细说明,请参见并行编程入门。一些常见的 JAX 分片策略,如数据并行、完全分片数据并行和基本张量并行,将在实现高性能部分介绍。

示例

假设我们有一个设备网格 {"x": 2, "y": 4} 和一个形状为 (32, 64, 64, 128) 的数组 param。如果我们使用 jax.P(None, “x”, “y”, None) ` 来分片这个数组,我们会得到大小为 ``(32, 32, 16, 128)` 的分片,分布在各个设备上。None 表示一个轴不应被分片。JAX 会隐式地广播尾随轴,因此可以使用 jax.P(None, “x”, “y”) 更简洁地实现相同的分片。因此,完全复制的数组(无论维度如何)的简写是 jax.P()

示例

更高级的网格几何结构在与我们设备的通信层次结构对齐时会很方便。主机到主机的通信通常比加速器到加速器的通信慢。假设我们有两个主机,每个主机连接八个 GPU。可以将设备排列成一个形状为 {"host": 2, "gpu": 8} 的网格。然后,我们可以对参数进行如下分片:

param = jnp.zeros((256, 192), out_sharding=jax.P("gpu", None))

整个 param 将被复制两次,但在每个主机内部,它将被分布到八个本地连接的 GPU 上,每个 GPU 在 HBM 中存储形状为 (32, 192) 的分片。这对于完全分片数据并行 (FSDP)特别有用。

训练状态初始化#

@jax.jit
def init_train_state(config: Config) -> dot_dict:
  train_state = dot_dict()
  train_state.params = init_param_state(config)
  train_state.opt = jax.tree.map(init_adam_state, train_state.params)
  return train_state

在开始之前,我们需要做的第一件事是设置训练状态。训练状态(顾名思义)封装了训练过程的所有有状态方面。这通常至少包括模型参数和优化器状态。我们构建此函数的方式(尽管您也可以选择其他方式)是:

  1. 创建一系列嵌套字典来存储模型参数,然后

  2. jax.tree.map() 遍历这些参数,生成一个类似的嵌套字典集合来存储相应的优化器状态。(更多内容请参见下文。)

参数初始化#

@jax.jit
def init_train_state(config: Config) -> dot_dict:
  train_state = dot_dict()
  train_state.params = init_param_state(config)
  train_state.opt = jax.tree.map(init_adam_state, train_state.params)
  return train_state

为了初始化我们的参数,我们构建了一系列嵌套字典,它们对应于神经网络的语义部分。如果我们使用像 PyTorch 或 Flax 这样的基于层的库,这些可能对应于神经网络层。在本例中,我们实际上可以使用一个完全展平的字典,但嵌套方法既方便与 JAX 中的一些 API 配合使用,也有助于组织我们的代码。

def init_param_state(config: Config) -> dot_dict:
  root_key = jax.random.key(config.param_seed)
  key = map(ft.partial(jax.random.fold_in, root_key), it.count())
  zero_init = jax.nn.initializers.constant(0.0)
  he_init = jax.nn.initializers.he_normal(1, 1)
  dtype = config.dtype

  params = dot_dict(
    pos_embed=zero_init(next(key), (config.seq_length, config.embed_dim), dtype, config.pos_embed),
    layers=dot_dict(),
  )
  params.embedding = he_init(next(key), (config.vocab_size, config.embed_dim), dtype, config.embed)
  params.linear_in = dot_dict(
    kernel=he_init(next(key), (1, config.embed_dim), dtype, config.in_kernel),
    bias=zero_init(next(key), (config.embed_dim,), dtype, config.in_bias),
  )
  params.linear_out = dot_dict(
    kernel=he_init(next(key), (config.embed_dim, config.vocab_size), dtype, config.out_kernel),
  )
  for layer in range(config.num_layers):
    qkv_shape = (3, config.embed_dim, config.num_heads, config.head_dim)
    out_shape = (config.num_heads, config.head_dim, config.embed_dim)
    params.layers[layer] = dot_dict(
      attention=dot_dict(
        qkv=he_init(next(key), qkv_shape, dtype, config.att_qkv),
        out=he_init(next(key), out_shape, dtype, config.att_out),
      ),
      mlp=dot_dict(
        in_kernel=he_init(next(key), (config.embed_dim, config.mlp_dim), dtype, config.mlp_in),
        out_kernel=he_init(next(key), (config.mlp_dim, config.embed_dim), dtype, config.mlp_out),
      ),
    )
  return params

我们的 get_param_state 函数利用了 jax.nn.initializers 中提供的 constanthe_normal 工厂。这些工厂返回一个初始化器,它是一个符合以下协议的函数:

class Initializer(Protocol):
    def __call__(self, key, shape, dtype, out_sharding) -> jax.Array:
        ...

JAX 的函数式风格需要显式处理所有随机性(即伪随机数),因此我们设置了一个小的迭代器,它会产生 PRNG 密钥。然后,为了构建我们的参数,我们在 params 嵌套字典的相应位置初始化它们,提供来自 Config 类的参数形状、dtype 和分片。

注意

通过在此处指定分片,我们直接在设备网格中正确的位置上初始化每个参数的每个分片,从而避免了不必要的设备到主机传输,或者在模型不适合系统内存的情况下,避免了内存不足的错误。

优化器初始化#

@jax.jit
def init_train_state(config: Config) -> dot_dict:
  train_state = dot_dict()
  train_state.params = init_param_state(config)
  train_state.opt = jax.tree.map(init_adam_state, train_state.params)
  return train_state

在设置优化器状态方面,事情不像构建模型参数那样简单。 Adam 优化器要求我们为每个参数保留三个优化状态:munucount。其中最简单的是 count,它存储了我们执行的训练步数。这只是一个用于去偏 Adam 更新的标量。 munu 状态将是与相应的参数 param 具有相同形状、dtype 和分片的数组[2]

def init_adam_state(param: jax.Array) -> dot_dict:
  adam_state = dot_dict(mu=jnp.zeros_like(param), nu=jnp.zeros_like(param), count=jnp.array(0))
  return adam_state

当我们使用 jax.tree.map() 时,它会遍历 train_state.params 中的项。对于每个参数,它会创建一个相应的 Adam 状态,从而生成一个与 train_state.params 结构相匹配的新嵌套字典。这个新结构中的每个叶子都包含相应参数的优化器状态。

训练步(函数式转换)#

@jax.jit
def train_step(config: Config, train_state: dot_dict, batch: dict) -> dict:
  def loss_fn(params):
    logits = model_apply(config, params, batch["observed_ids"])
    labels = jax.nn.one_hot(batch["target_ids"], config.vocab_size)
    return -(labels * jax.nn.log_softmax(logits)).mean()

  params = jax.tree.map(jax.ref.get, train_state.params)
  loss, grad = jax.value_and_grad(loss_fn)(params)
  jax.tree.map(ft.partial(adam_update, config), train_state.params, grad, train_state.opt)
  metrics = {"train_loss": loss}
  return metrics

训练步是我们计算模型相对于当前参数的梯度,并使用梯度和优化器来更新参数的地方。为了在 JAX 中做到这一点,我们定义了模型的正向传播,然后利用 JAX 的函数式转换自动生成反向传播,我们用它来计算梯度并执行更新。

模型前向传播#

def model_apply(config: Config, params: dot_dict, tokens: jax.Array) -> jax.Array:
  out = params.embedding.at[tokens].get(out_sharding=config.act_seq)
  out += params.pos_embed
  del tokens

  for layer in range(config.num_layers):
    block = params.layers[layer]
    att_skip = out  # 1 billion dollars in venture capital funding please
    qkv = jnp.einsum("bsd,3dkh->bs3kh", out, block.attention.qkv, out_sharding=config.act_att)
    out = jax.nn.dot_product_attention(qkv[:, :, 0, :], qkv[:, :, 1, :], qkv[:, :, 2, :], is_causal=True)
    out = jnp.einsum("bskh,khd->bsd", out, block.attention.out, out_sharding=config.act_seq)
    out += att_skip
    out *= jax.lax.rsqrt(jnp.linalg.norm(out, axis=-1, keepdims=True) + 1e-6)

    mlp_skip = out  # machine learning circa 1986
    out = jnp.einsum("bsd,dh->bsh", out, block.mlp.in_kernel, out_sharding=config.act_hidden)
    out = jax.nn.gelu(out)
    out = jnp.einsum("bsh,hd->bsd", out, block.mlp.out_kernel, out_sharding=config.act_seq)
    out += mlp_skip
    out *= jax.lax.rsqrt(jnp.linalg.norm(out, axis=-1, keepdims=True) + 1e-6)

  logits = jnp.einsum("bsd,dl->bsl", out, params.linear_out.kernel, out_sharding=config.act_seq)
  return logits

除了我们提供的 out_sharding 注解外,模型的前向传播大部分都很普通。这些注解声明了操作执行后结果的分片方式。编译器使用这些激活分片以及我们在初始化模型时提供的参数分片,动态插入通信集合操作,以传输参数和激活数据到设备之间。通过选择好的分片策略,我们可以实现高性能的训练(和推理)代码。我们将在标题为实现高性能的部分介绍一些标准策略。有关分片策略设计原理的详细讨论,请参见The Scaling Cookbook

梯度与优化器更新#

@jax.jit
def train_step(config: Config, train_state: dot_dict, batch: dict) -> dict:
  def loss_fn(params):
    logits = model_apply(config, params, batch["observed_ids"])
    labels = jax.nn.one_hot(batch["target_ids"], config.vocab_size)
    return -(labels * jax.nn.log_softmax(logits)).mean()

  params = jax.tree.map(jax.ref.get, train_state.params)
  loss, grad = jax.value_and_grad(loss_fn)(params)
  jax.tree.map(ft.partial(adam_update, config), train_state.params, grad, train_state.opt)
  metrics = {"train_loss": loss}
  return metrics

为了计算梯度,我们定义了训练损失。这是一个参数的函数,它返回一个标量,总结了我们的模型在当前 train_state 参数下对数据的解释程度。

loss, grad = jax.value_and_grad(loss_fn)(params)

通过将此函数提供给 jax.value_and_grad(),我们将其转换为一个函数,该函数返回 loss_fnparams 处计算的值和梯度(即梯度)。由于我们已将参数定义为一系列嵌套字典,因此梯度也将是反映参数的嵌套字典系列。请记住,与参数不同,优化器状态包含一些额外的、更深的嵌套字典,对应于每个参数的优化器状态。在阅读解释之前,请花点时间思考以下函数调用的语义可能是什么:

jax.tree.map(ft.partial(adam_update, config), train_state.params, grad, train_state.opt)

查看函数 adam_apply 的调用签名可以给我们一个提示:

def adam_update(config: Config, param: jax.Ref, grad: jax.Array, adam_state: dot_dict):
  adam_state.mu[...] = (1 - config.beta_1) * adam_state.mu[...] + config.beta_1 * grad
  adam_state.nu[...] = (1 - config.beta_2) * adam_state.nu[...] + config.beta_2 * grad**2
  adam_state.count[...] += 1

  mu_hat = adam_state.mu[...] / (1 - config.beta_1 ** adam_state.count[...])
  nu_hat = adam_state.nu[...] / (1 - config.beta_2 ** adam_state.count[...])
  param[...] -= config.learning_rate * mu_hat / (jnp.sqrt(nu_hat + config.eps_root) + config.eps)

由于 train_state.params 是第一个参数,jax.tree.map() 使用它的树结构来指导映射过程。[#prefix_tree]_ 这意味着 train_state.opt 只会遍历到 train_state.params 的叶子节点。因此,每个参数的优化器状态将作为完整的子树传入,这使得我们可以在 adam_apply 中轻松访问给定 param 的所有相关状态(如 munu)。

提示

如果我们希望在模型中使用不同的优化算法和状态对不同的参数(或冻结某些参数),我们可以通过修改 adam_apply 的主体,并将 jax.tree.map() 替换为 jax.tree_util.tree_map_with_path() 来实现这一点,后者允许操作函数根据参数自定义其行为。

训练循环#

def train_loop(config: Config):
  record_writer = RecordWriter()
  train_state = init_train_state(config)
  train_state = jax.tree.map(jax.ref.new_ref, train_state)
  batch = iter(get_dataset_on_device(config))
  for step in range(config.num_train_steps):
    metrics = train_step(config, train_state, next(batch))
    record_writer({"step": step} | metrics)

在训练过程中,我们必须协调主机系统和加速器之间的**数据流**。确保这些系统之间的顺畅交互是编写高性能训练代码的关键。Python 的全局解释器锁 (GIL) 通常会在此处构成重大障碍,但为了规避这一点,JAX 采用的异步分派范式使得这种协调很容易实现。但是,为了利用这种范式,我们在构建训练步时需要注意代码的执行方式。

通过异步分派实现效率#

主机系统执行的最重要的任务之一是获取数据并将其放到加速器上,以便加速器永远不会等待数据。加速器在训练步之间空闲等待的时间被称为“步泡”。我们可以利用异步分派来最小化步泡。让我们看看我们的训练循环是如何工作的,暂时忽略有关 record_writer 的那一行。

for step in range(config.num_train_steps):
  metrics = train_step(config, train_state, next(batch))

当这段代码执行时,Python 首先会查询 range 迭代器,获取 step(值为 0),然后调用 next(batch),这需要一些时间来检索批次。然后,调用 train_step。到目前为止,一切都很正常。

接下来发生的事情很有趣。因为jax.jit() 装饰的调用是非阻塞的,所以对 train_step 的调用会立即返回到 Python 解释器。虽然计算已在加速器上排队,但实际上尚未执行任何工作。Python 循环继续,推进步计数器,并调用下一个迭代的 next(batch)。一旦做出第二次 train_step 调用,其输入就变成了上一个 JIT 调用中 train_state 的变异引用和一个新的数据批次。运行时很智能,它看到为了执行第二个 train_step 调用,我们首先需要实现第 0 步的 train_state 结果来执行变异。因此,它会启动第一步的计算,而且至关重要的是,在此过程中,train_step 再次立即返回,循环再次跳过。Python 现在继续执行,直到遇到第 3 步的 next(batch) 函数,该函数在 Python 中执行,加载数据,同时第一个训练步正在执行(这次是真实的)。就这样,我们可以同时加载数据并在加速器上进行数学运算,而无需任何传统的进程间通信。 [4]

        ---
displayMode: compact
---
gantt
    title Synchronous Dispatch: No Overlap
    axisFormat %

    section Host
    next(batch) :gb0, 0, 1000s
    next(batch) :gb1, after ajc0, 1000s
    next(batch) :gb2, after ajc1, 1000s

    section Accelerator

    train_step 0 :ajc0, after gb0, 2000s
    train_step 1 :ajc1, after gb1, 2000s
    
        ---
displayMode: compact
---
gantt
    title JAX Asynchronous Dispatch: Host-Device Overlap
    axisFormat %

    section Host
    %% Task: id, name, start, duration_or_end
    next(batch) :gb0, 0, 1000s
    next(batch) :gb1, after gb0, 1000s
    next(batch) :gb2, after gb1, 1000s
    next(batch) :gb3, after jc0, 1000s
    next(batch) :gb4, after jc1, 1000s

    section Accelerator
    %% Task: id, name, start, duration_or_end
    train_step 0 :jc0, after gb1, 2000s
    train_step 1 :jc1, after jc0, 2000s
    train_step 2 :jc2, after jc1, 2000s
    

常见错误#

在编写 Python 中的异步分派代码时,有两个主要错误需要注意,以免干扰我们精心编排的计算。

请求设备到主机传输#

到目前为止,我们忽略了变量 metrics 的处理。事实上,如果将其悬空,则什么都不会发生,并且我们将获得广告中所宣传的良好重叠。然而,大多数情况下,我们希望从训练步中观察遥测数据,例如当前损失、梯度统计等。假设我们插入了如下代码:

metrics = train_step(config, train_state, next(batch))
print({"step": step} | metrics)

循环并没有继续前进,而是 print 操作会引起 metrics 中任何在设备上的数组的设备到主机传输。这会中断 Python 解释器,代码被迫同步执行,从而产生一个步泡。解决方案有点违反直觉:在每一步,我们收集前一步的遥测数据。

class RecordWriter:
  prev_metrics = None

  def __call__(self, cur_metrics: dict):
    self.prev_metrics, log_metrics = cur_metrics, self.prev_metrics
    if log_metrics is None:
      return
    print(*it.starmap("{}: {}".format, log_metrics.items()), sep="\t")

metrics = train_step(config, train_state, next(batch))

像这样的小辅助函数对于实现良好的重叠并充分利用我们的主机系统和加速器的资源至关重要。当然,这里的简单 print 语句可以替换为任何从加速器请求数据的 Python 操作。

中断加速器#

浪费大量云计算资金的另一个常见方式是无意中在训练步之外将数学运算排队到加速器上。假设我们正在使用余弦学习率调度器。

def learning_rate(count, init_value: float = 1e-4, decay_steps: int = 10_000, alpha: float = 1e-6):
    cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * jnp.minimum(count, decay_steps) / decay_steps))
    return init_value * (1 - alpha) * cosine_decay

一个常见的模式是希望将调度器与我们正在收集的其他指标一起可视化。然而,即使我们使用了之前定义的巧妙的 record_writer 类,以下代码也会在加速器上创建一个步泡。

metrics = train_step(config, train_state, next(batch))
record_writer({"step": step, "learning_rate": learning_rate(step)} | metrics)

这是因为我们在计算中使用了jax.numpy。当调用jax.numpy.minimum()时,Python 整数 step 会被提升为jax.Array并传输到加速器(主机到设备传输)。计算现在在加速器上排队,位于我们的主 train_step 之外。要 print 结果,值必须传输回主机(设备到主机传输)。这种往返迫使加速器与主机同步,我们通过创建性能气泡浪费了金钱。避免这种情况的两种方法是使用 NumPy 进行这些计算,或者使用 jax.default_device() 上下文管理器。

metrics = train_step(config, train_state, next(batch))
with jax.default_device('cpu'):
  record_writer({"step": step, "learning_rate": learning_rate(step)} | metrics)

数据加载#

除了重叠数据本身的实际加载(即将其从网络存储检索到主机)之外,JAX 还允许我们将数据的主机到设备传输与训练步的计算重叠。特殊函数 jax.device_put() 的设计使其成为非阻塞的、异步执行的,因此在我们的训练步中使用它是完全可以的。然而,有一个更方便的函数专门用于数据加载任务。在以下代码中,dataset 是一个普通的 Python 迭代器,它会产生一个包含批处理数据的 dict。通过使用 jax.make_array_from_process_local_data() 映射此迭代器,我们生成了一个新的迭代器。从这个新迭代器产生数据将生成放置在设备上的数据,已准备好供我们的训练步使用。在内部,它将使用 jax.tree.map() 来创建 jax.Array 对象并将它们排队以传输到设备。前提是数据能够足够快地进行批处理,在 TPU 和 GPU 上,这些传输将与训练步的计算重叠。

def get_dataset_on_device(config: Config) -> Iterator[dict[str, jax.Array]]:
  datset = get_dataset(config)
  sharding = jax.P(config.mesh_axis_names)
  return map(ft.partial(jax.make_array_from_process_local_data, sharding), datset)

实现高性能#

在本节中,我们将描述训练中有用的三种主要的模型并行形式。在训练过程中,吞吐量至关重要;也就是说,我们希望最大化每秒操作的平均数量。这与推理不同,推理的目标是通过确保所有操作在尽可能短的时间内完成来最小化延迟。考虑到吞吐量是我们训练的最终目标,本节将介绍训练期间进行分片的三个主要策略。对于每种策略,我们都将概述实现它的 JAX 分片,并描述涉及的通信集合,以便在研究程序跟踪时,您会有一些地标来确认程序是否按预期运行。下面代码块中定义的 JAX 分片变量对应于它们在初始化模型前向传播中的用法。但在配套脚本中,这些以及训练代码的其他方面都通过全局 Config 类方便地设置。

@jax.tree_util.register_static
@dataclass(kw_only=True, frozen=True)
class Config:
  mesh_axis_names: tuple[str, ...] = ("fsdp",)
  mesh_shape: tuple[int, ...] = (8,)
  seq_length: int = 128

  num_train_steps: int = 10**6
  host_batch_size: int = 16
  learning_rate: float = 1e-4
  beta_1: float = 0.9
  beta_2: float = 0.999
  eps: float = 1e-8
  eps_root: float = 0.0

  param_seed: int = 12738
  num_layers: int = 4
  embed_dim: int = 512
  mlp_dim: int = 512 * 4
  vocab_size: int = 2**8  # uint8 ascii encoding
  num_heads: int = 8
  head_dim: int = 128
  dtype: str = "bfloat16"

  embed: jax.P = jax.P(None, None)
  pos_embed: jax.P = jax.P(None, None)
  att_qkv: jax.P = jax.P(None, "fsdp", None, None)
  att_out: jax.P = jax.P("fsdp", None, None)
  mlp_in: jax.P = jax.P("fsdp", None)
  mlp_out: jax.P = jax.P(None, "fsdp")
  in_kernel: jax.P = jax.P(None, None)
  in_bias: jax.P = jax.P(None)
  out_kernel: jax.P = jax.P("fsdp", None)
  out_bias: jax.P = jax.P(None)

  act_ids: jax.P = jax.P("fsdp")
  act_seq: jax.P = jax.P("fsdp", None, None)
  act_att: jax.P = jax.P("fsdp", None, None, None)
  act_hidden: jax.P = jax.P("fsdp", None, None)

  def __post_init__(self):
    mesh = jax.make_mesh(self.mesh_shape, self.mesh_axis_names, len(self.mesh_shape) * (AxisType.Explicit,))
    jax.sharding.set_mesh(mesh)

数据并行#

数据并行是最常见且易于理解的并行形式。在此方案中,每个加速器存储模型参数的完整副本,我们沿批次轴分片激活以分割梯度计算。为了计算梯度,每个加速器执行单独的前向和后向传播。然后,在更新参数之前,XLA 会插入一个 AllReduce 来共享更新并保持模型同步。

网格

mesh = jax.sharding.Mesh(jax.devices(), ('devices',))

参数分片

pos_embed = jax.P(None, None)
att_qkv = jax.P(None, None, None, None)
att_out = jax.P(None, None, None)
mlp_in = jax.P(None, None)
mlp_out = jax.P(None, None)
in_kernel = jax.P(None, None)
in_bias = jax.P(None)
out_kernel = jax.P(None, None)
out_bias = jax.P(None)

激活分片

act_ids = jax.P("devices")
act_seq = jax.P("devices", None, None)
act_att = jax.P("devices", None, None, None)
act_hidden = jax.P("devices", None, None)

完全分片数据并行 (FSDP)#

数据并行分片的缺点是我们必须在 HBM 中保留多个完整、冗余的模型参数副本。这对于小型模型来说是一种非常高性能的策略,但由于 HBM 供应紧张,我们也需要分片模型参数。在完全分片数据并行 (FSDP) 策略中,我们同时分片模型和参数。现在,随着前向传播的进行,参数一个接一个地(通过 AllGather)解开分片成为完整的数组,然后它们才应用于激活。然而,这种解分片是短暂的,因此可以大大节省 HBM。在后向传播中,每个 AllGather 变成一个 ReduceScatter。然后在优化器更新时有一个最终的 ReduceScatter 来同步梯度。与数据并行相比,总通信量增加了 50%,但 HBM 压力降低了模型大小除以设备数量。

网格

mesh = jax.sharding.Mesh(jax.devices(), ('fsdp',))

参数分片

pos_embed = jax.P(None, None)
att_qkv = jax.P(None, "fsdp", None, None)
att_out = jax.P("fsdp", None, None)
mlp_in = jax.P("fsdp", None)
mlp_out = jax.P(None, "fsdp")
in_kernel = jax.P(None, None)
in_bias = jax.P(None)
out_kernel = jax.P("fsdp", None)
out_bias = jax.P(None)

激活分片

act_ids = jax.P("fsdp")
act_seq = jax.P("fsdp", None, None)
act_att = jax.P("fsdp", None, None, None)
act_hidden = jax.P("fsdp", None, None)

注意

虽然 FSDP 需要比数据并行更多的通信,但实际上我们能够将通信与计算重叠,从而隐藏通信,并在大大改进的 HBM 预算下实现相同的吞吐量。

张量并行#

如果我们的模型足够大且结构合适,那么将单个示例内的计算分布到我们的加速器上就会变得有益。以矩阵乘法为例,我们可以将大型矩阵乘法分布到两个或四个加速器上。这会涉及显著更多的通信,因此这种策略仅适用于算术强度非常高的计算,例如极大型矩阵乘法。对于多头自注意力机制,我们选择沿头部分片,并保持序列轴的复制,因为这提供了最自然的并行度。如果 MLP 足够大,我们也可以高效地分片矩阵乘法。

网格

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(128, 4), ("fsdp", "tensor"))

参数分片

pos_embed = jax.P(None, "tensor")
att_qkv = jax.P(None, "fsdp", "tensor", None)
att_out = jax.P("fsdp", None, None)
mlp_in = jax.P("fsdp", "tensor")
mlp_out = jax.P("tensor", "fsdp")
in_kernel = jax.P(None, None)
in_bias = jax.P(None)
out_kernel = jax.P("fsdp", None)
out_bias = jax.P(None)

激活分片

act_ids = jax.P("fsdp")
act_seq = jax.P("fsdp", None, None)
act_att = jax.P("fsdp", None, "tensor", None)
act_hidden = jax.P("fsdp", None, "tensor")