JAX 内存和主机卸载#

本教程提供了 JAX 中主机卸载技术的实践入门,重点关注

  • 激活卸载

  • 参数卸载

  • 优化器状态卸载

通过应用卸载策略,开发人员可以更好地管理内存资源并减轻设备上的内存压力。要有效地实现这些策略,理解 JAX 用于数据放置和移动的核心机制至关重要。

卸载的基础组件#

JAX 提供了几个关键组件来控制数据存储的位置以及如何在主机和设备内存之间移动数据。以下各节将探讨

  • 如何使用 sharding 指定数据分布

  • 如何控制主机和设备之间的内存放置

  • 如何在 jitted 函数中管理数据移动

NamedSharding 和内存类型#

NamedSharding 定义了数据如何在设备之间分布。它包括

  • 基本数据分布配置

  • memory_kind 参数用于指定内存类型(devicepinned_host

  • 默认情况下,memory_kind 设置为 device 内存

  • with_memory_kind 方法用于创建具有修改后的内存类型的新 sharding

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import numpy as np

# Create mesh
# 1x1 mesh represents a single device with two named dimensions (x and y)
mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y'))

# Device sharding - partitions data along x and y dimensions
s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind="device")

# Host sharding - same partitioning but in pinned host memory
s_host = s_dev.with_memory_kind('pinned_host')

print(s_dev)   # Shows device memory sharding
print(s_host)  # Shows pinned host memory sharding
NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)
NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)

使用 device_put 进行数据放置#

jax.device_put() 是一个函数,它根据 sharding 规范显式地将数组传输到指定的内存位置。

# Create a 2x4 array
arr = jnp.arange(8.0).reshape(2, 4)

# Move arrays to different memory locations based on sharding objects
arr_host = jax.device_put(arr, s_host)  # Places in pinned host memory
arr_dev = jax.device_put(arr, s_dev)    # Places in device memory

# Verify memory locations
print(arr_host.sharding.memory_kind)  # Output: pinned_host
print(arr_dev.sharding.memory_kind)   # Output: device
pinned_host
device

输出 Sharding 控制#

Shardings 决定了数据如何在设备之间分割。JAX 提供了 out_shardings 来控制当输出数组离开 jitted 函数时它们的分割方式。

主要特点

  • 可以与输入 sharding 不同

  • 允许输出使用不同的内存类型

示例

设备输出 Sharding#

f = jax.jit(lambda x:x, out_shardings=s_dev)
out_dev = f(arr_host)
print("Result value of H2D: \n", out_dev)
Result value of H2D: 
 [[0. 1. 2. 3.]
 [4. 5. 6. 7.]]

将数据从主机移动到设备内存以供计算是主机卸载的本质。在此示例中,使用 jax.device_put() 来执行此传输,以优化性能。

# Instead of the lambda function, add_func can be defined explicitly
# move data to device before computation
def add_func(x):  # Move data to device and add one
  x = jax.device_put(x, s_dev)
  return x + 1

f = jax.jit(add_func, out_shardings=s_dev)
out_dev = f(arr_host)
print("Result value of H2D and add 1 in device memory: \n", out_dev)
Result value of H2D and add 1 in device memory: 
 [[1. 2. 3. 4.]
 [5. 6. 7. 8.]]

主机输出 Sharding#

f = jax.jit(lambda x: x, out_shardings=s_host)
out_host = f(arr_dev)      # Input arrays in the device memory while output arrays in the host memory
print("Result value of D2H: \n", out_host)
Result value of D2H: 
 [[0. 1. 2. 3.]
 [4. 5. 6. 7.]]

激活卸载#

在深入研究激活卸载之前,让我们先看一下基线代码。

这段代码实现了一个简单的神经网络,包含 10 层,每层由两个线性变换组成。代码演示了基本的内存使用模式,并为比较卸载优化技术提供了基础。

关键组件

  • 每层由两个顺序的线性操作组成

    1. 第一次乘法:x @ w1

    2. 第二次乘法:y @ w2

  • 使用 JAX 的 scan 操作构建的 10 层网络

  • 内存使用分析

  • 使用 JIT 编译进行梯度计算

要分析 JAX 中的内存使用情况,可以在已编译的函数上使用 jax.stages.Compiled.memory_analysis() 方法。这提供了有关计算过程中内存消耗的详细统计信息。关键指标包括临时内存大小、参数大小、输出大小和别名大小。要计算总内存使用量,请将临时、参数和输出大小相加,然后减去别名大小,以避免多次重复计算相同的内存。这提供了计算不同方面如何利用设备内存的概览。

# Initialize input and weights with small values (0.0001)
input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001  # Input matrix: 256 x 256
w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices
w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices

def two_layers(x, w):
  # Simple two-layer linear transformation
  w1, w2 = w
  y = x @ w1
  return y @ w2, None

def scanned(w, x):
  # Applies the layer function 10 times using JAX's scan operation
  # Input: w (tuple of weight matrices), x (input matrix)
  # Output: sum of the final layer's output
  result = jax.lax.scan(two_layers, x, w)[0]
  return jnp.sum(result)

# Compile and compute gradients of the scanned function
f = jax.jit(jax.grad(scanned))  # Apply JIT compilation to gradient computation

# Analyze memory usage
compiled_step = f.lower((w1, w2), input).compile()
compiled_stats = compiled_step.memory_analysis()

if compiled_stats is not None:
  # Calculate total memory usage including temporary storage, arguments, and outputs
  # Subtract alias size to avoid double-counting memory shared between different components
  total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
      + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
  print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB")
  print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB")
  print(f"Total size: {total/(1024**2):.2f} MB")

# Execute the function and print sample results
result = f((w1, w2), input)     # Execute the function with weights and input
print("Sample of results: ", result[0][0, 0, :5])
Temp size: 17.25 MB
Argument size: 20.25 MB
Total size: 57.50 MB
Sample of results:  [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]

有关激活卸载的详细介绍,请参阅 使用 jax.checkpoint (jax.remat) 进行梯度检查点 教程。激活卸载通过在正向传播后将中间激活移动到主机内存,并在需要进行梯度计算的反向传播期间将其带回到设备内存来帮助管理内存。

要有效地实现激活卸载,理解检查点名称和策略非常重要。以下是一个简单示例,展示了它们的工作原理

检查点名称#

函数 checkpoint_name() 允许在计算过程中为内存管理标记激活。这里有一个简单示例,指定了检查点名称 x

from jax.ad_checkpoint import checkpoint_name

def layer_name(x, w):
  w1, w2 = w
  x = checkpoint_name(x, "x")
  y = x @ w1
  return y @ w2, None

检查点名称帮助系统决定是否

  • 将激活保留在设备内存中,或者

  • 在计算过程中将其卸载到主机内存

这种模式在神经网络中很常见,其中多个变换按顺序应用于输入数据。

检查点策略#

此检查点策略实现了一种内存管理策略,可在计算过程中优化内存使用。它通过三种策略处理中间值来管理内存

  1. 在反向传播期间重新计算(默认行为)

  2. 存储在设备上

  3. 在正向传播后卸载到主机内存,并在反向传播期间加载回

from jax import checkpoint_policies as cp

policy = cp.save_and_offload_only_these_names(
    names_which_can_be_saved=[],          # No values stored on device
    names_which_can_be_offloaded=["x"],   # Offload activations labeled "x"
    offload_src="device",                 # Move from device memory
    offload_dst="pinned_host"             # To pinned host memory
)

jax.lax.scan() 在 JAX 中常用于处理顺序操作(如 RNN 或 Transformer)。它可以与 JAX 的重构集成以处理顺序数据。

关键组件

  • jax.remat() 使用 jax.remat() 创建层函数的可重构版本,并将检查点策略应用于层函数

  • prevent_cse=False 启用 XLA 的公共子表达式消除以获得更好的性能

  • jax.lax.scan() 沿着一个轴迭代可重构层

def scanned(w, x):
  remat_layer = jax.remat(layer_name,
                          policy=policy,     # Use our offloading policy
                          prevent_cse=False) # Allow CSE optimizations
  result = jax.lax.scan(remat_layer, x, w)[0]
  return jnp.sum(result)

# Initialize input and weights with small values (0.0001)
input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001  # Input matrix: 256 x 256
w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices
w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices

# Compile and compute gradients of the scanned function
f = jax.jit(jax.grad(scanned))  # Apply JIT compilation to gradient computation

# Analyze memory usage
compiled_step = f.lower((w1, w2), input).compile()
compiled_stats = compiled_step.memory_analysis()

if compiled_stats is not None:
  total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
      + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
  print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB")
  print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB")
  print(f"Total size: {total/(1024**2):.2f} MB")

result_activation = f((w1, w2), input)     # Execute the function with weights and input
# Verify numerical correctness
are_close = jnp.allclose(
    result_activation[0],    # Result from activation offloading only
    result[0],         # Result from both activation and parameter offloading
    rtol=1e-5,
    atol=1e-5
)
print(f"Results match within tolerance: {are_close}")
print("Sample of results: ", result_activation[0][0, 0, :5])
Temp size: 6.50 MB
Argument size: 20.25 MB
Total size: 46.75 MB
Results match within tolerance: True
Sample of results:  [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]

激活卸载将临时内存使用量从 17.25 MB 减少到 6.5 MB,而输入和输出参数的大小保持不变。总共节省了 10.75 MB。这是通过在正向传播后将激活 x 卸载到主机内存,并在反向传播之前将其加载回设备内存来实现的。

激活卸载总结#

激活卸载提供了一种强大的内存管理方式,适用于大型计算,通过

  • 使用检查点名称来标记特定的激活

  • 应用策略来控制激活的存储位置和方式

  • 支持常见的 JAX 模式,如 scan 操作

  • 当设备内存不足时,将选定的激活移动到主机内存

这种方法在处理大型模型时尤其有用,否则这些模型将超出设备内存容量。

参数卸载#

模型参数(也称为权重)可以在初始化期间卸载到主机内存,以优化设备内存使用。这是通过使用指定主机内存类型的 sharding 策略的 jax.jit() 来实现的。

虽然参数卸载和激活卸载是不同的内存优化技术,但以下示例演示了基于前面所示的激活卸载实现构建的参数卸载。

用于计算的参数放置#

与之前的 layer 函数不同,在此应用 jax.device_put() 来在矩阵乘法之前将参数 w1w2 移动到设备。这确保了参数在设备上可用于正向和反向传播。

请注意,激活卸载实现保持不变,使用相同的

  • 检查点名称 "x"

  • 检查点策略

  • scanned 函数结合了 jax.remat()jax.lax.scan()

使用主机卸载初始化参数#

在初始化期间,参数 w1w2 被放置在主机内存中,然后传递给 jax.jit() 函数 f,同时将 input 变量保留在设备上。

# Hybrid version: Both activation and parameter offloading
def hybrid_layer(x, w):
  # Move model parameters w1 and w2 to device memory via device_put
  w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)
  x = checkpoint_name(x, "x")  # Offload activation x to host memory
  y = x @ w1
  return y @ w2, None

def hybrid_scanned(w, x):
  remat_layer = jax.remat(hybrid_layer,     # Use hybrid_layer instead of layer
                          policy=policy,     # Use offloading policy
                          prevent_cse=False) # Allow CSE optimizations
  result = jax.lax.scan(remat_layer, x, w)[0]
  return jnp.sum(result)

# Move model parameters w1 and w2 to the host via device_put
# Initialize input and weights with small values (0.0001)
wh1 = jax.device_put(w1, s_host)
wh2 = jax.device_put(w2, s_host)

# Compile and compute gradients of the scanned function
f = jax.jit(jax.grad(hybrid_scanned))  # Apply JIT compilation to gradient computation

# Analyze memory usage
compiled_step = f.lower((wh1, wh2), input).compile()
compiled_stats = compiled_step.memory_analysis()

if compiled_stats is not None:
  total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
      + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
  print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB")
  print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB")
  print(f"Total size: {total / (1024**2):.2f} MB")

result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading

# Verify numerical correctness
are_close = jnp.allclose(
    result_activation[0],    # Result from activation offloading only
    result_both[0],         # Result from both activation and parameter offloading
    rtol=1e-5,
    atol=1e-5
)
print(f"Results match within tolerance: {are_close}")
Temp size: 4.75 MB
Argument size: 0.25 MB
Total size: 25.00 MB
Results match within tolerance: True

此实现演示了如何将模型参数与激活一起卸载到主机内存,可以显著减少设备内存使用。

内存分析#

基线内存使用

  • 输入张量:0.25 MB(256 × 256 × 4 字节)

  • 模型参数(w1,w2):每个 10 MB(256 × 1024 × 4 字节 ≈ 每层 1 MB × 10 层)

内存使用比较

  • 不进行参数卸载时的参数大小:20.25 MB(0.25 + 10 + 10)

  • 进行参数卸载时的参数大小:0.25 MB(仅保留输入)

  • 不进行激活卸载时的临时内存:17.25 MB

  • 进行激活卸载时的临时内存:6.50 MB

  • 进行激活和参数卸载时的临时内存:4.75 MB

关键优化#

  1. 参数卸载:将参数(w1,w2)移动到主机内存可将参数大小减少 20 MB(从 20.25 MB 减少到 0.25 MB)。

  2. 激活卸载:将激活移动到主机内存可将临时内存使用量减少 10.75 MB(从 17.25 MB 减少到 6.50 MB)。

  3. 混合策略:激活卸载的重构有助于避免将权重保留在设备上,并将临时内存使用量减少 1.75 MB(从 6.50 MB 减少到 4.75 MB)。如果没有它,JAX 会急于在反向传播期间将权重在设备上的副本保留下来。

结果#

总内存节省:33.5 MB(20 MB + 10.75 MB + 1.75 MB)

这种混合方法表明,参数和激活卸载协同工作,可在保持计算正确性的同时实现显著的内存减少。

参数卸载的局限性#

对于有效的参数管理,jax.lax.scan() 至关重要。使用显式 for 循环会导致参数持续占用设备内存,从而导致内存使用量与不进行参数卸载时相同。虽然 jax.lax.scan() 允许指定 scan 轴,但参数卸载目前仅在扫描轴 0 时有效。扫描其他轴会在编译期间生成一个 transpose 操作,然后才能将参数返回到设备,这很昂贵且并非在所有平台上都支持。

不同设备类型的卸载性能可能会有所不同。由于主机和设备之间的内存传输,它可能会降低性能,因此在设计优化策略时考虑这种权衡非常重要。

优化器状态卸载#

优化器状态卸载是一种内存管理技术,它将优化器状态存储在主机内存而不是设备内存中。当优化器状态很大时,这种方法特别有用,因为它减少了设备内存的使用。

一个使用 Adam 优化器的基本 JAX 实现可以作为起点,其中所有张量都存储在设备上。在引入优化器状态卸载之前,这将作为参考实现。

基本实现#

在本节中,我们将实现一个带有 Adam 优化器的简单模型。此实现有助于在探索优化器状态卸载之前建立基线行为。它对于理解大规模神经网络训练中的内存模式尤其有用。

在下面的代码示例中,包含了一个神经网络训练循环,以使用 JAX 和 Optax 的 Adam 优化器。该网络包含四个具有 GELU 激活函数的线性层,处理大小为 7168x7168 的大型矩阵。训练过程包括

  • 正向传播:输入流经四个层,每个层应用线性变换后接 GELU 激活

  • 损失计算:计算输出和输入之间的均方误差,加上 L2 正则化

  • 反向传播:使用自动微分计算梯度

  • 优化步骤:使用带有梯度裁剪的 Adam 优化器更新参数

代码使用 JIT 编译来优化性能,并包括内存使用分析来监控训练过程中所需的计算资源。内存分析提供了对优化步骤期间的临时内存使用、参数大小和总内存消耗的洞察。

import optax

DIM = 7168

# Initialize data and parameter w1, w2, w3 and w4
input = jnp.ones((DIM, DIM))
params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}

# Initialize optimizer
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=0.1)
)
opt_state = optimizer.init(params)

def gelu(x):
  return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))

def single_layer(x, w):
  return x @ w

def forward(params, x):
  for i in range(1, 5):
    x = gelu(single_layer(x, params[f'w{i}']))
  return x

def compute_loss(params, inputs):
  outputs = forward(params, inputs)
  loss = jnp.mean((outputs - inputs) ** 2)
  l2_reg = 0.001 * sum(jnp.sum(w ** 2) for w in jax.tree_util.tree_leaves(params))
  return loss + l2_reg

def step(params, opt_state, inputs):
  grads = jax.grad(lambda p: compute_loss(p, inputs))(params)
  updates, new_opt_state = optimizer.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), new_opt_state

# JIT compile the step function with proper sharding
step = jax.jit(step, donate_argnums=(0, 1))

# Run a optimization step
new_params, new_opt_state = step(params, opt_state, input)

# Analyze memory usage
compiled_step = step.lower(params, opt_state, input).compile()
compiled_stats = compiled_step.memory_analysis()

if compiled_stats is not None:
  total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
      + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
  print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB")
  print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} GB")
  print(f"Total size: {total / (1024**3):.2f} GB")
Temp size: 2.11 GB
Argument size: 2.49 GB
Total size: 4.59 GB

优化器状态卸载可以按如下方式实现。

设置 Sharding 和内存类型#

采用 jax.sharding.SingleDeivceSharding() 来简化设备和主机内存类型的 shardings。在模型状态初始化期间,使用 device_put() 将优化器状态移动到主机。

模型和训练步骤实现#

接下来,定义模型架构、损失函数和训练步骤。这里新增的关键是,由于优化器状态在设备上的参数更新需要,因此在每个训练步骤的开始时通过 device_put() 将其移动到设备内存。

运行和比较结果#

设置好 sharding 后,优化器状态被移动到主机内存,并且使用 jax.jit() 运行 step 函数。

step 函数的 JIT 编译使用几个重要参数

  • donate_argnums=(0,):表示第一个参数(参数)可以就地修改,允许 JAX 重用其内存

  • out_shardings:指定输出张量应如何在网格(设备和主机)之间进行分片

# Create sharding specifications for device and host memory
s_dev = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="device")
s_host = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")

def step(params, opt_state, inputs):
  grads = jax.grad(lambda p: compute_loss(p, inputs))(params)
  opt_state = jax.device_put(opt_state, s_dev)
  updates, new_opt_state = optimizer.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state

params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}
opt_state = optimizer.init(params)

# Initialize optimizer
optimizer = optax.chain(
  optax.clip_by_global_norm(1.0),
  optax.adam(learning_rate=0.1)
)

# Optimizer state is placed on the host during initialization
opt_state = jax.device_put(opt_state, s_host)

# JIT compile the step function with proper sharding and memory optimization
step = jax.jit(
  step,
  donate_argnums=(0,),
  out_shardings=(s_dev, s_host)
)

# Run an optimization step
new_params, offload_opt_state = step(params, opt_state, input)

# Analyze memory usage
compiled_step = step.lower(params, opt_state, input).compile()
compiled_stats = compiled_step.memory_analysis()
if compiled_stats is not None:
  total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
      + compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
  print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB")
  print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} MB")
  print(f"Total size: {total / (1024**3):.2f} GB")
Temp size: 1.91 GB
Argument size: 0.96 MB
Total size: 2.87 GB

此实现演示了如何

  1. devicepinned_host 设置 sharding 规范

  2. 通过 jax.device_put() 在主机和设备内存之间移动优化器状态

  3. 使用 out_shardings 以确保正确的内存放置

  4. 显示内存使用情况

此实现演示了如何通过参数大小和临时内存之间的权衡,将优化器状态卸载到主机内存可以减少设备内存使用。

内存分析

  1. 参数大小减小

    • 优化器状态是 jax.jit() 函数的参数

    • 通过将这些状态卸载到主机内存,设备上的参数大小得以减小

  2. 临时内存影响

    • 卸载会增加临时内存使用

    • 这是因为优化器状态的输出在复制到主机之前需要内存缓冲区

    • 由于主机-设备传输,这些临时缓冲区的内存生命周期得以延长

  3. 延迟隐藏调度

    • JAX 使用 XLA 的延迟隐藏调度来重叠计算与主机-设备传输

    • 重叠会导致张量具有更长的生命周期,这会增加设备的内存压力

    • 这种自适应行为有助于保持稳定的内存使用,同时仍提供一些性能优势

  4. 内存权衡

    • 使用卸载时的总内存大小:2.87 GB

    • 不使用卸载时的总内存大小:4.59 GB

    • 净内存节省:1.72 GB

虽然卸载会增加临时内存的使用,但参数大小的减小足以弥补这种增加,从而导致设备内存使用量的总体减少。

注意:可以使用 jax.tree_util.tree_mapjnp.allclose 来比较优化器状态的数值等价性,但为简洁起见,此处省略了此验证步骤。

主机卸载工具#

上面使用了 jax.stages.Compiled.memory_analysis() API 来获取内存使用信息。有关设备内存分析,请参阅 :doc:device_memory_profilingProfiling and Tracing 中描述的分析工具可以帮助衡量主机卸载带来的内存节省和性能影响。