JAX 内存与主机卸载#
本教程提供了 JAX 中主机卸载技术的实用介绍,重点关注:
激活值卸载
参数卸载
优化器状态卸载
通过应用卸载策略,开发者可以更好地管理内存资源并减轻设备上的内存压力。为了有效实施这些策略,理解 JAX 数据放置和移动的核心机制至关重要。
卸载的基础构建块#
JAX 提供了几个关键组件,用于控制数据在主机和设备内存之间的存储和移动方式及位置。以下部分将探讨:
如何使用分片指定数据分布
如何控制主机和设备之间的内存放置
如何管理 JIT 编译函数中的数据移动
命名分片 (NamedSharding) 与内存类型 (Memory Kinds)#
NamedSharding
定义了数据如何在设备间分布。它包括:
基本数据分布配置
用于指定内存类型(
device
或pinned_host
)的memory_kind
参数默认情况下,
memory_kind
设置为device
内存用于创建具有修改后内存类型的新分片的
with_memory_kind
方法
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import numpy as np
# Create mesh
# 1x1 mesh represents a single device with two named dimensions (x and y)
mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y'))
# Device sharding - partitions data along x and y dimensions
s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind="device")
# Host sharding - same partitioning but in pinned host memory
s_host = s_dev.with_memory_kind('pinned_host')
print(s_dev) # Shows device memory sharding
print(s_host) # Shows pinned host memory sharding
NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)
NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)
使用 device_put 进行数据放置#
jax.device_put()
是一个函数,根据分片规范将数组显式传输到指定的内存位置。
# Create a 2x4 array
arr = jnp.arange(8.0).reshape(2, 4)
# Move arrays to different memory locations based on sharding objects
arr_host = jax.device_put(arr, s_host) # Places in pinned host memory
arr_dev = jax.device_put(arr, s_dev) # Places in device memory
# Verify memory locations
print(arr_host.sharding.memory_kind) # Output: pinned_host
print(arr_dev.sharding.memory_kind) # Output: device
pinned_host
device
输出分片控制#
分片决定了数据如何在设备间拆分。JAX 提供了 out_shardings
来控制输出数组在离开 JIT 编译函数时如何分区。
关键特性
可以与输入分片不同
允许输出使用不同的内存类型
示例
设备输出分片#
f = jax.jit(lambda x:x, out_shardings=s_dev)
out_dev = f(arr_host)
print("Result value of H2D: \n", out_dev)
Result value of H2D:
[[0. 1. 2. 3.]
[4. 5. 6. 7.]]
在计算需要时将数据从主机内存移动到设备内存是主机卸载的精髓。在此示例中,使用 jax.device_put()
执行此传输以优化性能。
# Instead of the lambda function, add_func can be defined explicitly
# move data to device before computation
def add_func(x): # Move data to device and add one
x = jax.device_put(x, s_dev)
return x + 1
f = jax.jit(add_func, out_shardings=s_dev)
out_dev = f(arr_host)
print("Result value of H2D and add 1 in device memory: \n", out_dev)
Result value of H2D and add 1 in device memory:
[[1. 2. 3. 4.]
[5. 6. 7. 8.]]
主机输出分片#
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_host = f(arr_host) # Input arrays in the device memory while output arrays in the host memory
print("Result value of D2H: \n", out_host)
Result value of D2H:
[[0. 1. 2. 3.]
[4. 5. 6. 7.]]
激活值卸载#
在深入探讨激活值卸载之前,我们先来看一下基线代码。
此代码实现了一个简单的神经网络,包含 10 个层,每个层由两个线性变换组成。该代码展示了基本的内存使用模式,并为比较卸载优化技术提供了基础。
关键组件
每个层由两个顺序线性操作组成
第一次乘法:
x @ w1
第二次乘法:
y @ w2
使用 JAX 的扫描 (scan) 操作的 10 层网络
内存使用分析
使用 JIT 编译进行梯度计算
要分析 JAX 中的内存使用情况,可以在编译后的函数上使用 :func:`jax.stages.Compiled.memory_analysis` 方法。这提供了计算过程中内存消耗的详细统计信息。关键指标包括临时内存大小、参数大小、输出大小和别名大小。要计算总内存使用量,请将临时内存、参数和输出大小相加,然后减去别名大小,以避免多次重复计算同一内存。这提供了一个汇总视图,说明设备内存在计算的不同方面如何被利用。
# Initialize input and weights with small values (0.0001)
input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256
w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices
w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices
def two_layers(x, w):
# Simple two-layer linear transformation
w1, w2 = w
y = x @ w1
return y @ w2, None
def scanned(w, x):
# Applies the layer function 10 times using JAX's scan operation
# Input: w (tuple of weight matrices), x (input matrix)
# Output: sum of the final layer's output
result = jax.lax.scan(two_layers, x, w)[0]
return jnp.sum(result)
# Compile and compute gradients of the scanned function
f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation
# Analyze memory usage
compiled_step = f.lower((w1, w2), input).compile()
compiled_stats = compiled_step.memory_analysis()
if compiled_stats is not None:
# Calculate total memory usage including temporary storage, arguments, and outputs
# Subtract alias size to avoid double-counting memory shared between different components
total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
+ compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB")
print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB")
print(f"Total size: {total/(1024**2):.2f} MB")
# Execute the function and print sample results
result = f((w1, w2), input) # Execute the function with weights and input
print("Sample of results: ", result[0][0, 0, :5])
Temp size: 17.25 MB
Argument size: 20.25 MB
Total size: 57.50 MB
Sample of results: [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]
激活值卸载的详细内容可以在“使用 jax.checkpoint (即 jax.remat) 进行梯度检查点”教程中找到。激活值卸载通过在正向传播后将中间激活值移动到主机内存,并在反向传播需要计算梯度时将其移回设备内存来帮助管理内存。
为了有效实现激活值卸载,理解检查点名称和策略非常重要。这是一个简单的工作示例:
检查点名称#
checkpoint_name()
函数允许在计算过程中为激活值标记名称,以便进行内存管理。这是一个指定检查点名称 x
的简单示例。
from jax.ad_checkpoint import checkpoint_name
def layer_name(x, w):
w1, w2 = w
x = checkpoint_name(x, "x")
y = x @ w1
return y @ w2, None
检查点名称帮助系统决定是:
将激活值保留在设备内存中,还是
在计算过程中将其卸载到主机内存
这种模式在神经网络中很常见,其中多个变换按顺序应用于输入数据。
检查点策略#
此检查点策略实现了在计算过程中优化内存使用的内存管理策略。它通过三种策略处理中间值来管理内存:
在反向传播期间重新计算(默认行为)
存储在设备上
在正向传播后卸载到主机内存并在反向传播期间重新加载
from jax import checkpoint_policies as cp
policy = cp.save_and_offload_only_these_names(
names_which_can_be_saved=[], # No values stored on device
names_which_can_be_offloaded=["x"], # Offload activations labeled "x"
offload_src="device", # Move from device memory
offload_dst="pinned_host" # To pinned host memory
)
jax.lax.scan()
在 JAX 中常用于处理顺序操作(如 RNN 或 Transformer)。它可以与 JAX 的重新具象化(rematerialization)集成以处理顺序数据。
关键组件
jax.remat()
使用jax.remat()
创建层函数的重新具象化版本,并将检查点策略应用于层函数prevent_cse=False
启用 XLA 的公共子表达式消除以提高性能jax.lax.scan()
沿着一个轴迭代重新具象化层
def scanned(w, x):
remat_layer = jax.remat(layer_name,
policy=policy, # Use our offloading policy
prevent_cse=False) # Allow CSE optimizations
result = jax.lax.scan(remat_layer, x, w)[0]
return jnp.sum(result)
# Initialize input and weights with small values (0.0001)
input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256
w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices
w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices
# Compile and compute gradients of the scanned function
f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation
# Analyze memory usage
compiled_step = f.lower((w1, w2), input).compile()
compiled_stats = compiled_step.memory_analysis()
if compiled_stats is not None:
total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
+ compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB")
print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB")
print(f"Total size: {total/(1024**2):.2f} MB")
result_activation = f((w1, w2), input) # Execute the function with weights and input
# Verify numerical correctness
are_close = jnp.allclose(
result_activation[0], # Result from activation offloading only
result[0], # Result from both activation and parameter offloading
rtol=1e-5,
atol=1e-5
)
print(f"Results match within tolerance: {are_close}")
print("Sample of results: ", result_activation[0][0, 0, :5])
Temp size: 6.50 MB
Argument size: 20.25 MB
Total size: 46.75 MB
Results match within tolerance: True
Sample of results: [3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07 3.8312336e-07]
激活值卸载将临时内存使用量从 17.25 MB 减少到 6.5 MB,而输入和输出参数大小保持不变。总共节省了 10.75 MB。这是通过在正向传播后将激活值 x
卸载到主机内存,并在反向传播前将其加载回设备内存来实现的。
激活值卸载总结#
激活值卸载通过以下方式为大型计算提供了强大的内存管理方法:
使用检查点名称标记特定激活值
应用策略控制激活值的存储位置和方式
支持 JAX 中常见的模式,如扫描操作
当设备内存不足时,将选定的激活值移动到主机内存
这种方法在处理那些否则会超出设备内存容量的大型模型时特别有用。
参数卸载#
模型参数(也称为权重)可以卸载到主机内存,以在初始化期间优化设备内存使用。这通过使用 jax.jit()
和指定主机内存类型的分片策略来实现。
虽然参数卸载和激活值卸载是不同的内存优化技术,但以下示例演示了在前面所示的激活值卸载实现基础上构建的参数卸载。
用于计算的参数放置#
与之前的 layer
函数不同,这里将 jax.device_put()
应用于在矩阵乘法之前将参数 w1
和 w2
移动到设备。这确保了参数在正向和反向传播期间都可在设备上使用。
请注意,激活值卸载的实现保持不变,使用相同的:
检查点名称
"x"
检查点策略
结合
jax.remat()
和jax.lax.scan()
的scanned
函数
使用主机卸载进行参数初始化#
在初始化期间,参数 w1
和 w2
被放置在主机内存中,然后才传递给 jax.jit()
函数 f
,同时 input
变量保留在设备上。
# Hybrid version: Both activation and parameter offloading
def hybrid_layer(x, w):
# Move model parameters w1 and w2 to host memory via device_put
w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)
x = checkpoint_name(x, "x") # Offload activation x to host memory
y = x @ w1
return y @ w2, None
def hybrid_scanned(w, x):
remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer
policy=policy, # Use offloading policy
prevent_cse=False) # Allow CSE optimizations
result = jax.lax.scan(remat_layer, x, w)[0]
return jnp.sum(result)
# Move model parameters w1 and w2 to the host via device_put
# Initialize input and weights with small values (0.0001)
wh1 = jax.device_put(w1, s_host)
wh2 = jax.device_put(w2, s_host)
# Compile and compute gradients of the scanned function
f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation
# Analyze memory usage
compiled_step = f.lower((wh1, wh2), input).compile()
compiled_stats = compiled_step.memory_analysis()
if compiled_stats is not None:
total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
+ compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**2):.2f} MB")
print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**2):.2f} MB")
print(f"Total size: {total / (1024**2):.2f} MB")
result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading
# Verify numerical correctness
are_close = jnp.allclose(
result_activation[0], # Result from activation offloading only
result_both[0], # Result from both activation and parameter offloading
rtol=1e-5,
atol=1e-5
)
print(f"Results match within tolerance: {are_close}")
Temp size: 4.75 MB
Argument size: 0.25 MB
Total size: 25.00 MB
Results match within tolerance: True
该实现演示了将模型参数与激活值一同卸载到主机内存如何显著减少设备内存使用。
内存分析#
基线内存使用
输入张量:0.25 MB (256 × 256 × 4 字节)
模型参数 (w1, w2):每个 10 MB (256 × 1024 × 4 字节 ≈ 每层 1 MB × 10 层)
内存使用比较
无参数卸载时的参数大小:20.25 MB (0.25 + 10 + 10)
带参数卸载时的参数大小:0.25 MB (仅剩输入)
无激活值卸载时的临时内存:17.25 MB
带激活值卸载时的临时内存:6.50 MB
带激活值和参数卸载时的临时内存:4.75 MB
关键优化#
参数卸载:将参数 (w1, w2) 移动到主机内存将参数大小减少了 20 MB(从 20.25 MB 到 0.25 MB)。
激活值卸载:将激活值移动到主机内存将临时内存使用量减少了 10.75 MB(从 17.25 MB 到 6.50 MB)。
混合策略:激活值卸载的重新具象化有助于避免将权重保留在设备上,并将临时内存使用量减少了 1.75 MB(从 6.50 MB 到 4.75 MB)。没有它,JAX 会急于在反向传播时保留权重的设备副本。
结果#
总内存节省:33.5 MB (20 MB + 10.75 MB + 1.75 MB)
这种混合方法表明,参数和激活值卸载协同工作,可以在保持计算正确性的同时实现显著的内存减少。
参数卸载的局限性#
jax.lax.scan()
对于有效的参数管理至关重要。使用显式 for 循环会导致参数持续占用设备内存,从而导致与不使用参数卸载时相同的内存使用量。尽管 jax.lax.scan()
允许指定扫描轴,但参数卸载目前仅在沿着轴 0 扫描时有效。沿着其他轴扫描会在编译期间生成一个 transpose
操作,然后才将参数返回到设备,这代价高昂且并非所有平台都支持。
卸载性能可能因不同的设备类型而异。它可能会因为主机和设备之间的内存传输而降低性能,因此在设计优化策略时考虑这种权衡很重要。
优化器状态卸载#
优化器状态卸载是一种内存管理技术,它将优化器状态存储在主机内存中而不是设备内存中。这种方法在优化器状态较大时特别有用,因为它减少了设备内存使用。
使用 Adam 优化器的基本 JAX 实现可以作为起点,其中所有张量都存储在设备上。这将作为引入优化器状态卸载之前的参考实现。
基本实现#
在本节中,让我们使用 Adam 优化器实现一个简单的模型。此实现有助于在探索优化器状态卸载之前建立基线行为。它对于理解大规模神经网络训练中的内存模式特别有用。
在下面的代码示例中,包含了一个神经网络训练循环,以使用 JAX 和 Optax 的 Adam 优化器。该网络由四个带有 GELU 激活函数的线性层组成,处理大小为 7168x7168 的大型矩阵。训练过程包括:
正向传播:输入流经四个层,每个层应用线性变换,然后是 GELU 激活
损失计算:计算输出和输入之间的均方误差,加上 L2 正则化
反向传播:使用自动微分计算梯度
优化步骤:使用 Adam 优化器和梯度裁剪更新参数
该代码使用 JIT 编译来优化性能,并包含内存使用分析以监控训练期间所需的计算资源。内存分析提供了关于优化步骤期间临时内存使用、参数大小和总内存消耗的洞察。
import optax
DIM = 7168
# Initialize data and parameter w1, w2, w3 and w4
input = jnp.ones((DIM, DIM))
params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}
# Initialize optimizer
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(learning_rate=0.1)
)
opt_state = optimizer.init(params)
def gelu(x):
return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))
def single_layer(x, w):
return x @ w
def forward(params, x):
for i in range(1, 5):
x = gelu(single_layer(x, params[f'w{i}']))
return x
def compute_loss(params, inputs):
outputs = forward(params, inputs)
loss = jnp.mean((outputs - inputs) ** 2)
l2_reg = 0.001 * sum(jnp.sum(w ** 2) for w in jax.tree_util.tree_leaves(params))
return loss + l2_reg
def step(params, opt_state, inputs):
grads = jax.grad(lambda p: compute_loss(p, inputs))(params)
updates, new_opt_state = optimizer.update(grads, opt_state, params)
return optax.apply_updates(params, updates), new_opt_state
# JIT compile the step function with proper sharding
step = jax.jit(step, donate_argnums=(0, 1))
# Run a optimization step
new_params, new_opt_state = step(params, opt_state, input)
# Analyze memory usage
compiled_step = step.lower(params, opt_state, input).compile()
compiled_stats = compiled_step.memory_analysis()
if compiled_stats is not None:
total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
+ compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB")
print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} GB")
print(f"Total size: {total / (1024**3):.2f} GB")
Temp size: 2.11 GB
Argument size: 2.49 GB
Total size: 4.59 GB
优化器状态卸载可以按如下方式实现。
设置分片与内存类型#
采用 jax.sharding.SingleDeviceSharding()
来简化设备和主机内存类型的分片。在模型状态初始化期间,使用 device_put()
将优化器状态移动到主机。
模型与训练步骤实现#
接下来,定义模型架构、损失函数和训练步骤。这里的关键新增点是在每个训练步骤开始时通过 device_put()
将优化器状态移动到设备内存,因为设备上的参数更新需要它。
运行并比较结果#
设置分片后,优化器状态被移动到主机内存,并使用 jax.jit()
运行步进函数。
步进函数的 JIT 编译使用了几个重要参数:
donate_argnums=(0,)
:表示第一个参数(参数)可以就地修改,允许 JAX 重用其内存out_shardings
:指定输出张量应如何在网格(设备和主机)上分片
# Create sharding specifications for device and host memory
s_dev = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="device")
s_host = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")
def step(params, opt_state, inputs):
grads = jax.grad(lambda p: compute_loss(p, inputs))(params)
opt_state = jax.device_put(opt_state, s_dev)
updates, new_opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
return new_params, new_opt_state
params = {f'w{i}': jnp.ones((DIM, DIM)) for i in range(1, 5)}
opt_state = optimizer.init(params)
# Initialize optimizer
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(learning_rate=0.1)
)
# Optimizer state is placed on the host during initialization
opt_state = jax.device_put(opt_state, s_host)
# JIT compile the step function with proper sharding and memory optimization
step = jax.jit(
step,
donate_argnums=(0,),
out_shardings=(s_dev, s_host)
)
# Run an optimization step
new_params, offload_opt_state = step(params, opt_state, input)
# Analyze memory usage
compiled_step = step.lower(params, opt_state, input).compile()
compiled_stats = compiled_step.memory_analysis()
if compiled_stats is not None:
total = compiled_stats.temp_size_in_bytes + compiled_stats.argument_size_in_bytes \
+ compiled_stats.output_size_in_bytes - compiled_stats.alias_size_in_bytes
print(f"Temp size: {compiled_stats.temp_size_in_bytes / (1024**3):.2f} GB")
print(f"Argument size: {compiled_stats.argument_size_in_bytes / (1024**3):.2f} MB")
print(f"Total size: {total / (1024**3):.2f} GB")
Temp size: 1.91 GB
Argument size: 0.96 MB
Total size: 2.87 GB
此实现演示了如何:
为
device
和pinned_host
设置分片规范通过
jax.device_put()
在主机和设备内存之间移动优化器状态使用
out_shardings
确保正确的内存放置显示内存使用情况
此实现演示了将优化器状态卸载到主机内存如何通过在参数大小和临时内存之间进行权衡来减少设备内存使用。
内存分析
参数大小减少
优化器状态是
jax.jit()
函数的参数通过将这些状态卸载到主机内存,设备上的参数大小减少了
临时内存影响
卸载增加了临时内存使用
这是因为优化器状态的输出在复制到主机之前需要内存缓冲区
由于主机-设备传输,这些临时缓冲区的内存生命周期被延长
延迟隐藏调度
JAX 使用 XLA 的延迟隐藏调度来重叠计算与主机-设备传输
重叠可能导致张量具有更大的生命周期,从而增加设备上的内存压力
这种自适应行为有助于保持稳定的内存使用,同时仍提供一些性能优势
内存权衡
卸载后的总内存大小:2.87 GB
未卸载时的总内存大小:4.59 GB
净内存节省:1.72 GB
尽管卸载增加了临时内存使用量,但参数大小的减少足以弥补这一增长,从而总体上减少了设备内存使用。
注意:可以使用 jax.tree_util.tree_map
和 jnp.allclose
比较优化器状态的数值等效性,但此处为简洁起见省略了此验证步骤。
主机卸载工具#
上面使用了 :func:`jax.stages.Compiled.memory_analysis` API 来获取内存使用信息。有关设备内存分析,请参阅 :doc:`device_memory_profiling`。“性能分析与追踪”中描述的性能分析工具可以帮助衡量主机卸载带来的内存节省和性能影响。