GPU 性能优化技巧#

本文档侧重于神经网络工作负载的性能优化技巧。

矩阵乘法精度#

在较新的 GPU 代系(如 Nvidia A100 或更高版本)上,建议以 bfloat16 精度执行大多数计算。例如,如果使用 Flax,请使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例:

XLA 性能标志#

注意

JAX-Toolbox 也有一个关于 NVIDIA XLA 性能标志的页面。

XLA 标志的存在及其具体行为可能取决于 jaxlib 的版本。

截至 jaxlib==0.4.18(发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提升性能。其中一些与 GPU 之间的通信有关,因此仅在多设备计算时有效,而另一些则与每个设备上的代码生成有关。

在未来的版本中,其中一些可能会默认设置。

这些标志可以通过 XLA_FLAGS shell 环境变量进行设置。例如,我们可以将其添加到 Python 文件的顶部:

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)

代码生成标志#

  • –xla_gpu_triton_gemm_any 对于所有支持的 GEMM(矩阵乘法),使用基于 Triton 的 GEMM 发射器。默认值为 False。

通信优化技巧#

自动与手动 PGLE#

配置引导延迟估计器 (PGLE) 工作流会测量计算和集合通信的实际运行时间,并将这些性能分析信息反馈给 XLA 编译器,以便做出更好的调度决策。

配置引导延迟估计器既可以手动使用,也可以自动使用。在自动模式下,JAX 会在单次运行中收集性能分析信息并重新编译模块。而在手动模式下,你需要运行任务两次:第一次用于收集和保存分析数据,第二次用于利用已提供的数据进行编译和运行。

重要提示:下文所述的两种 PGLE 工作流所使用的 JAX 分析器不能与 NVIDIA Nsight Systems 分析器共存。这种限制可以通过使用下文所述的 JAX 编译缓存来规避。

自动 PGLE#

通过设置以下环境变量可以开启自动 PGLE:

强制:

export JAX_ENABLE_PGLE=true

# For JAX version <= 0.5.0 make sure to include:
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

可选:

export JAX_PGLE_PROFILING_RUNS=3
export JAX_PGLE_AGGREGATION_PERCENTILE=85

# Right now the auto PGLE profile collection doesn't work with command buffer.
# If the command buffer is enabled, Auto PGLE will disable it during profile
# collection and enable it back after the recompilation. If you need to have a
# consistent command buffer logic with and with PGLE profile you can disable it
# manually:
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"

或者在 JAX 中,可以按照以下方式进行设置:

import jax
from jax._src import config

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  # Run with the profiler collecting performance information.
  train_step()
  # Automatically re-compile with PGLE profile results
  train_step()
  ...

你可以通过更改 JAX_PGLE_PROFILING_RUNS 来控制用于收集分析数据的重运行次数。增加该参数会获得更精确的性能分析信息,但也会增加非优化训练步骤的数量。

如果步骤之间的性能波动过大,导致难以过滤掉无关的测量值,减小 JAX_PGLE_AGGREGATION_PERCENTILE 参数可能会有所帮助。

注意: 自动 PGLE 不适用于预编译模块。由于 JAX 需要在执行期间重新编译模块,因此自动 PGLE 既不适用于 AoT(提前编译),也不适用于以下情况:

import jax
from jax._src import config

train_step_compiled = train_step().lower().compile()

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  train_step_compiled()
  # No effect since module was pre-compiled.
  train_step_compiled()

使用 AutoPGLE 时收集 NVIDIA Nsight Systems 性能分析数据#

jax#24910(JAX v0.5.1 及更高版本)添加了一个新的 JAX 配置选项 JAX_COMPILATION_CACHE_EXPECT_PGLE,该选项告知 JAX 尝试从持久编译缓存中加载经 PGLE 优化的编译函数。

这允许执行两步流程,其中第一步将 PGLE 优化的函数写入缓存:

export JAX_ENABLE_COMPILATION_CACHE=yes          # not strictly needed, on by default
export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
JAX_ENABLE_PGLE=yes python my-model.py

第二步使用 Nsight Systems,并从缓存中加载 PGLE 优化的函数:

JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py

有关持久编译缓存及其潜在问题的更多信息,请参阅此页面

手动 PGLE#

如果你仍然希望使用手动配置引导延迟估计器,XLA/GPU 的工作流如下:

    1. 运行一次工作负载,启用异步集合通信和延迟隐藏调度程序。

你可以通过以下设置实现:

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
    1. 使用 JAX 分析器收集并后处理性能分析数据,将提取出的指令延迟保存到二进制 protobuf 文件中。

import os
from etils import epath
import jax
from jax.experimental import profiler as exp_profiler

# Define your profile directory
profile_dir = 'gs://my_bucket/profile'
jax.profiler.start_trace(profile_dir)

# run your workflow
# for i in range(10):
#   train_step()

# Stop trace
jax.profiler.stop_trace()
profile_dir = epath.Path(profile_dir)
directories = profile_dir.glob('plugins/profile/*/')
directories = [d for d in directories if d.is_dir()]
rundir = directories[-1]
logging.info('rundir: %s', rundir)

# Post process the profile
fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

# Save the profile proto to a file.
dump_dir = rundir / 'profile.pb'
dump_dir.parent.mkdir(parents=True, exist_ok=True)
dump_dir.write_bytes(fdo_profile)

此步骤完成后,你将在代码中打印出的 rundir 目录下获得一个 profile.pb 文件。

    1. 再次运行工作负载,并将该文件输入到编译过程中。

你需要将 profile.pb 文件传递给 --xla_gpu_pgle_profile_file_or_directory_path 标志。

 export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"

要启用 XLA 中的日志记录并检查配置是否正确,请将日志级别设置为包含 INFO

export TF_CPP_MIN_LOG_LEVEL=0

运行实际工作流。如果在运行日志中发现了这些记录,则说明分析器已在延迟隐藏调度程序中使用:

2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator

标志#

  • –xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度程序,以高效地重叠异步通信与计算。默认值为 False。

  • –xla_gpu_memory_limit_slop_factor 此标志作为应用到总可用内存的乘数,创建一个阈值,引导延迟隐藏调度程序 (LHS) 在减少内存占用和延迟隐藏优化之间取得平衡。默认值为 95。

    该因子有效地为编译器传递建立了一个内存限制,决定调度程序应优先考虑:

    1. 内存缩减:当内存使用量接近或超过计算出的阈值时。

    2. 延迟隐藏:当内存使用量低于阈值时,允许更激进的优化,这些优化可能会暂时增加内存使用量,但能提升整体性能。

    通过调整此因子,用户可以微调内存效率与性能优化之间的权衡。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于微调何时将多个小的 AllGather/ReduceScatter/AllReduce 操作合并为一个大的操作,以减少跨设备通信花费的时间。例如,对于基于 Transformer 的工作负载,考虑将 AllGather/ReduceScatter 阈值调整得足够高,以合并至少一个 Transformer 层权重的 AllGather/ReduceScatter 操作。默认情况下,combine_threshold_bytes 设置为 256。

GPU 上的流水线并行#

使用 XLA 标志#

XLA 实现了基于 SPMD 的流水线并行优化。这是一种扩展技术,将前向和后向传递拆分为多个流水线阶段。每个设备(或设备组)处理上一个流水线阶段(或流水线输入)的结果,并将部分结果发送到下一个阶段,直到到达流水线末尾。当计算延迟大于通信延迟时,此优化效果最佳。在编译时,操作将被重排以实现通信与计算的重叠。

为了获得优化的调度,我们推荐使用这些 XLA 标志:

--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''
--xla_disable_hlo_passes=collective-permute-motion
--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE

以下 JAX 示例展示了一种将通信操作调度为与计算重叠的模式。在此示例中,我们将展示如何使用 4 个组成通信环(设备 0 -> 设备 1 -> 设备 2 -> 设备 3 -> 设备 0)的 GPU 来设置优化的流水线并行调度。我们将模式 0 -> 1 -> 2 -> 3 称为前向边,将 3 -> 0 称为后向边。

# Imports and setup
import functools
import jax
from jax import sharding
from jax.experimental import mesh_utils
import jax.numpy as jnp
import jax.random

NUM_DEVICES = 4
NUM_MICROBATCHES = 5
NUM_CIRC_REPEATS = 2
CONTRACTING_DIM_SIZE = 4096
NON_CONTRACTING_DIM_SIZE = 8192
COMPUTE_INTENSITY = 32

# Creates a collective permute for the "forward edge".
# 0->1, 1->2, ... (N-2)->(N-1)
def shift_right(arr):
  padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
  # Use lax.slice to guarantee the gradient is a pad.
  return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)


# Creates a collective permute for the "back edge".
# (N-1)->0
def cycle_back(arr):
  padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
  return jax.lax.slice(
      jnp.pad(arr, padding),
      [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
      (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
  )


def select_on_first_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
  return jnp.where(is_first_device, then_value, else_value)


def select_on_last_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_last_device = (
      jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
  )
  return jnp.where(is_last_device, then_value, else_value)


def select_on_first_cycle(i, then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_cycle = i < NUM_MICROBATCHES
  return jnp.where(is_first_cycle, then_value, else_value)


def while_body(carry, i):
  """Body of the pipeline while loop."""
  weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry

  # Read input data from input buffer.
  input_data = jax.lax.dynamic_slice(
      input_buffer,
      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "forward edge" shifts data to the next stage.
  fwd_edge_data = shift_right(fwd_edge_data)

  # Select compute argument based on device and pipeline cycle.
  compute_argument = select_on_first_device(
      select_on_first_cycle(i, input_data, bwd_edge_data),
      fwd_edge_data,
  ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

  # A few matmuls to simulate compute.
  tmp = compute_argument
  for _ in range(COMPUTE_INTENSITY):
    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
  compute_result = tmp.reshape(
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  )

  # Read data from buffer to pass it to the first device of the pipeline on the
  # "back edge".
  bwd_edge_data = jax.lax.dynamic_slice(
      output_buffer,
      (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "back edge" passes data to the first device.
  bwd_edge_data = cycle_back(bwd_edge_data)

  # Update output buffer. We do this after reading from it to avoid the data
  # dependency.
  output_buffer = jax.lax.dynamic_update_slice(
      output_buffer,
      compute_result,
      (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
  )

  fwd_edge_data = compute_result
  carry = (
      weights,
      input_buffer,
      output_buffer,
      fwd_edge_data,
      bwd_edge_data,
  )
  return carry, i


@functools.partial(jax.jit, static_argnames=["mesh"])
def entry_computation(weights, input_buffer, mesh):

  # Init output buffer.
  output_buffer = jnp.zeros_like(input_buffer)

  # Init dummy data for forward and backward edge passed through the while loop.
  dummy_data = jnp.zeros(
      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  ).astype(jnp.float32)
  dummy_data = jax.device_put(
      dummy_data,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Start pipeline.
  carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
  _, _, output_buffer, _, _ = carry

  return output_buffer


def main(_):

  # Expect constant number of devices.
  assert NUM_DEVICES == jax.local_device_count()

  # Create mesh.
  mesh = sharding.Mesh(
      mesh_utils.create_device_mesh([NUM_DEVICES]),
      axis_names=["x"],
  )

  # Init weights.
  weights = 1.0 / CONTRACTING_DIM_SIZE
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
      broadcast_dimensions=(),
  )
  weights = jax.device_put(
      weights,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Init random input and replicate it across all devices.
  random_key = jax.random.key(0)
  input_buffer = jax.random.uniform(
      random_key,
      shape=(
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
  )
  input_buffer = jax.lax.broadcast_in_dim(
      input_buffer,
      shape=(
          NUM_DEVICES,
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
      broadcast_dimensions=[1, 2, 3],
  )
  input_buffer = jax.device_put(
      input_buffer,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Run computation.
  output_buffer = entry_computation(weights, input_buffer, mesh)
  print(f"output_buffer = \n{output_buffer}")

使用 psendprecv#

上述 JAX 示例会降低为 collective-permute HLO 指令,在 GPU 上通过 ncclSendncclRecv 实现。对于想要更精细地控制集合通信顺序的用户,可以直接使用 jax.lax.psendjax.lax.precv。从句法上讲,这两个函数类似于它们的 HLO 等价物。用户应注意,当单个 psendprecv 中的源-目标对形成循环,或者 psend 没有对应的 precv(反之亦然)时,程序会死锁。

如果在设备通信模式中确实需要循环,可以通过以下方式避免死锁:(1)确保没有任何单个 psendprecv 函数的源-目标对包含循环;(2)插入虚假数据依赖以序列化发送/接收对。在 psend/precv 对之间不能调度任何集合通信,这只能通过 JAX 级别的 jax.lax.optimization_barrier 来控制。shard_map_test.py 文件中的测试用例 test_psend_precv_basic_with_no_deadlock_cycle 就是一个例子。

上一节中的流水线并行示例使用了 --xla_gpu_experimental_pipeline_parallelism_opt_level XLA 标志。如果手动流水线化,同样的程序可以使用 psendprecv 重写,而无需使用该标志。

## same setup and imports
def while_body(carry, i):
  (
      weights,
      input_buffer,
      output_buffer,
      prev_compute_res,
      prev_stage_slice_fwd,
      prev_stage_slice_bwd,
  ) = carry

  # Read input data from input buffer.
  input_slice = jax.lax.dynamic_slice(
      input_buffer,
      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # send_fwd
  fwd_send_token = jax.lax.psend(
      prev_compute_res,
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  # Select compute argument based on device and pipeline cycle
  compute_argument = select_on_first_device(
      select_on_first_cycle(i, input_slice, prev_stage_slice_bwd),
      prev_stage_slice_fwd,
  ).reshape((1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

  tmp = compute_argument
  for _ in range(COMPUTE_INTENSITY):
    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
  compute_result = tmp.reshape(
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  )

  buffer_slice_for_bwd_ppermute = jax.lax.dynamic_slice(
      output_buffer,
      (0, (i + 1) % NUM_MICROBATCHES, 0, 0),
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # make sure ppermute is scheduled after send_fwd
  buffer_slice_for_bwd_ppermute_after_send_fwd, _ = (
      jax.lax.optimization_barrier(
          (buffer_slice_for_bwd_ppermute, fwd_send_token)
      )
  )
  # ppermute_bwd
  ppermute_bwd_data = jax.lax.ppermute(
      buffer_slice_for_bwd_ppermute_after_send_fwd,
      axis_name="x",
      perm=[(3, 0)],
  )

  # make sure recv is scheduled after ppermute
  precv_token, _ = jax.lax.optimization_barrier(
      (jax.lax.create_token(), ppermute_bwd_data)
  )

  # recv_fwd, matches the send_fwd in the next iteration
  fwd_recv_data = jax.lax.precv(
      precv_token,
      out_shape=jax.ShapeDtypeStruct(
          input_slice.shape, input_slice.dtype
      ),
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )
  update_output_buffer = jax.lax.dynamic_update_slice(
      output_buffer,
      compute_result,
      (0, (i + 2) % NUM_MICROBATCHES, 0, 0),
  )
  carry = (
      weights,
      input_buffer,
      update_output_buffer,
      compute_result,
      fwd_recv_data,
      ppermute_bwd_data,
  )
  return carry, i


def entry_computation(
    weights, input_buffer, dummy_data, mesh
):

  # Init output buffer.
  output_buffer = jnp.zeros_like(input_buffer)

  # Start pipeline.
  dummy_slice_fwd = jax.lax.precv(
      jax.lax.create_token(),
      jax.ShapeDtypeStruct(dummy_data.shape, dummy_data.dtype),
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  carry = (
      weights,
      input_buffer,
      output_buffer,
      dummy_slice_fwd,
      dummy_data,
      dummy_data,
  )

  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))

  _ = jax.lax.psend(
      carry[3],
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  _, _, output_buffer, _, _, _ = carry

  return output_buffer


def main(_):

  # Expect constant number of devices.
  assert NUM_DEVICES == jax.local_device_count()

  # Create mesh.
  mesh = Mesh(
      mesh_utils.create_device_mesh([NUM_DEVICES]),
      axis_names=["x"],
  )
  # Init weights.
  weights = 1.0 / CONTRACTING_DIM_SIZE
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
      broadcast_dimensions=(),
  )
  weights = jax.device_put(
      weights, NamedSharding(mesh, P("x"))
  )
  # Init input.
  random_key = jax.random.key(0)
  input_buffer = jax.random.uniform(
      random_key,
      shape=(
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
  )
  input_buffer = jax.lax.broadcast_in_dim(
      input_buffer,
      shape=(
          NUM_DEVICES,
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
      broadcast_dimensions=[1, 2, 3],
  )

  input_buffer = jax.device_put(
      input_buffer,
      NamedSharding(mesh, P("x")),
  )
  # Init dummy data for forward and backward edge passed through the while
  # loop.
  dummy_slice = jnp.zeros(
      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  ).astype(jnp.float32)
  dummy_data = jax.device_put(
      dummy_slice,
      NamedSharding(mesh, P("x")),
  )

  entry = partial(entry_computation, mesh=mesh)

  output_buffer = jax.jit(
      jax.shard_map(
          entry,
          mesh=mesh,
          in_specs=P("x"),
          out_specs=P("x"),
          check_vma=False,
      )
  )(weights, input_buffer, dummy_data)
  print(f"output_buffer = \n{output_buffer}")

NCCL 标志#

这些 Nvidia NCCL 标志值对于 Nvidia GPU 上的单主机多设备计算可能有用:

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

这些 NCCL 标志可以提高单主机通信速度。目前看来,这些标志对多主机通信似乎没有用处。

多进程#

我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 JIT 编译的计算。jax.distributed.initialize() API 在 SLURM 下运行时会自动识别该配置。不过,这仅是一个经验法则,在你的用例中测试“每个 GPU 一个进程”和“每个节点一个进程”这两种方式可能会有所帮助。