GPU 性能优化技巧#

本文档重点介绍神经网络工作负载的性能优化技巧

矩阵乘法精度#

在较新的 GPU 代(例如 Nvidia A100 或更高版本)上,建议以 bfloat16 精度执行大多数计算。例如,如果使用 Flax,请使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例

XLA 性能标志#

注意

JAX-Toolbox 也有关于 NVIDIA XLA 性能标志的页面。

XLA 标志的存在和确切行为可能取决于 jaxlib 的版本。

截至 jaxlib==0.4.182023 年 10 月 6 日发布),设置这些 XLA 标志可以提高性能。其中一些与 GPU 之间的通信有关,因此仅在多设备上运行计算时才相关,而另一些则与每个设备上的代码生成有关。

其中一些在未来的版本中可能会默认设置。

这些标志可以通过 XLA_FLAGS shell 环境变量设置。例如,我们可以将其添加到 Python 文件顶部

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)

更多示例,请参阅 在 Nvidia GPU 上进行 Pax 训练推荐的 XLA 标志

代码生成标志#

  • –xla_gpu_triton_gemm_any 对任何支持的 GEMM(矩阵乘法)使用基于 Triton 的 GEMM 发射器。默认值为 False。

通信技巧#

自动和手动 PGLE#

性能分析指导延迟估计器(Profile Guided Latency Estimator,PGLE)工作流测量计算和集合操作的实际运行时间,然后将分析信息反馈给 XLA 编译器,以做出更好的调度决策。

性能分析指导延迟估计器可以手动或自动使用。在自动模式下,JAX 会在单次运行中收集性能分析信息并重新编译模块。而在手动模式下,您需要运行任务两次,第一次用于收集和保存性能分析数据,第二次则使用提供的数据进行编译和运行。

重要:JAX 性能分析器(下文所述的两种 PGLE 工作流都会用到)不能与 NVIDIA Nsight Systems 性能分析器同时存在。可以通过使用 JAX 编译缓存来避免此限制,如下所述。

自动 PGLE#

可以通过设置以下环境变量来开启自动 PGLE

必需

export JAX_ENABLE_PGLE=true

# For JAX version <= 0.5.0 make sure to include:
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

可选

export JAX_PGLE_PROFILING_RUNS=3
export JAX_PGLE_AGGREGATION_PERCENTILE=85

# Right now the auto PGLE profile collection doesn't work with command buffer.
# If the command buffer is enabled, Auto PGLE will disable it during profile
# collection and enable it back after the recompilation. If you need to have a
# consistent command buffer logic with and with PGLE profile you can disable it
# manually:
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"

或者在 JAX 中可以按如下方式设置

import jax
from jax._src import config

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  # Run with the profiler collecting performance information.
  train_step()
  # Automatically re-compile with PGLE profile results
  train_step()
  ...

您可以通过更改 JAX_PGLE_PROFILING_RUNS 来控制用于收集性能分析数据的重复运行次数。增加此参数将获得更好的性能分析信息,但也会增加非优化训练步骤的数量。

在步骤间的性能噪音过大,无法滤除非相关测量值的情况下,降低 JAX_PGLE_AGGREGATION_PERCENTILE 参数可能会有所帮助。

注意:自动 PGLE 不适用于预编译模块。由于 JAX 需要在执行期间重新编译模块,因此自动 PGLE 既不适用于 AoT 也不适用于以下情况

import jax
from jax._src import config

train_step_compiled = train_step().lower().compile()

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  train_step_compiled()
  # No effect since module was pre-compiled.
  train_step_compiled()

使用 AutoPGLE 时收集 NVIDIA Nsight Systems 性能分析数据#

jax#24910 (JAX v0.5.1 及更高版本) 添加了一个新的 JAX 配置选项 JAX_COMPILATION_CACHE_EXPECT_PGLE,它指示 JAX 尝试从持久编译缓存加载经 PGLE 优化的编译函数。

这允许一个两步过程,第一步将 PGLE 优化的函数写入缓存

export JAX_ENABLE_COMPILATION_CACHE=yes          # not strictly needed, on by default
export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
JAX_ENABLE_PGLE=yes python my-model.py

第二步使用 Nsight Systems 并从缓存加载经 PGLE 优化的函数

JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py

另请参阅此页面,了解有关持久编译缓存和可能存在的陷阱的更多信息。

手动 PGLE#

如果您仍想使用手动性能分析指导延迟估计器,XLA/GPU 中的工作流如下

    1. 运行您的工作负载一次,并启用异步集合操作和延迟隐藏调度器。

您可以通过设置以下项来实现

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
    1. 使用 JAX 性能分析器收集并后处理性能分析数据,将提取的指令延迟保存到二进制 protobuf 文件中。

import os
from etils import epath
import jax
from jax.experimental import profiler as exp_profiler

# Define your profile directory
profile_dir = 'gs://my_bucket/profile'
jax.profiler.start_trace(profile_dir)

# run your workflow
# for i in range(10):
#   train_step()

# Stop trace
jax.profiler.stop_trace()
profile_dir = epath.Path(profile_dir)
directories = profile_dir.glob('plugins/profile/*/')
directories = [d for d in directories if d.is_dir()]
rundir = directories[-1]
logging.info('rundir: %s', rundir)

# Post process the profile
fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

# Save the profile proto to a file.
dump_dir = rundir / 'profile.pb'
dump_dir.parent.mkdir(parents=True, exist_ok=True)
dump_dir.write_bytes(fdo_profile)

此步骤之后,您将在代码中打印的 rundir 下获得一个 profile.pb 文件。

    1. 再次运行工作负载,将该文件作为编译输入。

您需要将 profile.pb 文件传递给 --xla_gpu_pgle_profile_file_or_directory_path 标志。

 export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"

要在 XLA 中启用日志记录并检查性能分析数据是否良好,请将日志级别设置为包含 INFO

export TF_CPP_MIN_LOG_LEVEL=0

运行实际工作流,如果您在运行日志中发现了这些日志信息,则表示性能分析器已在延迟隐藏调度器中使用。

2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator

标志#

  • –xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度器,以有效重叠异步通信与计算。默认值为 False。

  • –xla_gpu_memory_limit_slop_factor 此标志用作应用于总可用内存的乘数,创建了一个阈值,指导延迟隐藏调度器(LHS)在内存减少和延迟隐藏优化之间取得平衡。默认值为 95。

    此因子有效地为编译器通过设置了内存限制,决定了调度器何时应优先考虑

    1. 内存减少:当内存使用接近或超过计算出的阈值时。

    2. 延迟隐藏:当内存使用低于阈值时,允许进行更激进的优化,这可能会暂时增加内存使用,但会提高整体性能。

    通过调整此因子,用户可以微调内存效率和性能优化之间的权衡。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于调整何时将多个小的 AllGather/ReduceScatter/AllReduce 组合成一个大的 AllGather/ReduceScatter/AllReduce,以减少跨设备通信的时间。例如,对于基于 Transformer 的工作负载中的 AllGather/ReduceScatter 阈值,请考虑将其调整得足够高,以便至少组合一个 Transformer 层的权重 AllGather`/ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

GPU 上的流水线并行#

使用 XLA 标志#

XLA 实现了基于 SPMD 的流水线并行优化。这是一种扩展技术,其中前向和后向传播被分成多个流水线阶段。每个设备(或设备组)处理上一个流水线阶段的结果(或流水线输入),并将其部分结果发送到下一个阶段,直到流水线结束。当计算延迟大于通信延迟时,此优化效果最佳。在编译时,操作将重新排列以使通信与计算重叠。

为了获得优化的调度,我们推荐这些 XLA 标志

--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''
--xla_disable_hlo_passes=collective-permute-motion
--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE

以下 JAX 示例演示了一种通信操作与计算重叠调度的模式。在此示例中,我们将说明如何使用 4 个 GPU(它们形成一个通信环:设备 0 -> 设备 1 -> 设备 2 -> 设备 3 -> 设备 0)设置优化的流水线并行调度。我们将模式 0 -> 1 -> 2 -> 3 称为前向边,将 3 -> 0 称为后向边。

# Imports and setup
import functools
import jax
from jax import sharding
from jax.experimental import mesh_utils
import jax.numpy as jnp
import jax.random

NUM_DEVICES = 4
NUM_MICROBATCHES = 5
NUM_CIRC_REPEATS = 2
CONTRACTING_DIM_SIZE = 4096
NON_CONTRACTING_DIM_SIZE = 8192
COMPUTE_INTENSITY = 32

# Creates a collective permute for the "forward edge".
# 0->1, 1->2, ... (N-2)->(N-1)
def shift_right(arr):
  padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
  # Use lax.slice to guarantee the gradient is a pad.
  return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)


# Creates a collective permute for the "back edge".
# (N-1)->0
def cycle_back(arr):
  padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
  return jax.lax.slice(
      jnp.pad(arr, padding),
      [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
      (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
  )


def select_on_first_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
  return jnp.where(is_first_device, then_value, else_value)


def select_on_last_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_last_device = (
      jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
  )
  return jnp.where(is_last_device, then_value, else_value)


def select_on_first_cycle(i, then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_cycle = i < NUM_MICROBATCHES
  return jnp.where(is_first_cycle, then_value, else_value)


def while_body(carry, i):
  """Body of the pipeline while loop."""
  weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry

  # Read input data from input buffer.
  input_data = jax.lax.dynamic_slice(
      input_buffer,
      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "forward edge" shifts data to the next stage.
  fwd_edge_data = shift_right(fwd_edge_data)

  # Select compute argument based on device and pipeline cycle.
  compute_argument = select_on_first_device(
      select_on_first_cycle(i, input_data, bwd_edge_data),
      fwd_edge_data,
  ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

  # A few matmuls to simulate compute.
  tmp = compute_argument
  for _ in range(COMPUTE_INTENSITY):
    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
  compute_result = tmp.reshape(
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  )

  # Read data from buffer to pass it to the first device of the pipeline on the
  # "back edge".
  bwd_edge_data = jax.lax.dynamic_slice(
      output_buffer,
      (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "back edge" passes data to the first device.
  bwd_edge_data = cycle_back(bwd_edge_data)

  # Update output buffer. We do this after reading from it to avoid the data
  # dependency.
  output_buffer = jax.lax.dynamic_update_slice(
      output_buffer,
      compute_result,
      (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
  )

  fwd_edge_data = compute_result
  carry = (
      weights,
      input_buffer,
      output_buffer,
      fwd_edge_data,
      bwd_edge_data,
  )
  return carry, i


@functools.partial(jax.jit, static_argnames=["mesh"])
def entry_computation(weights, input_buffer, mesh):

  # Init output buffer.
  output_buffer = jnp.zeros_like(input_buffer)

  # Init dummy data for forward and backward edge passed through the while loop.
  dummy_data = jnp.zeros(
      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  ).astype(jnp.float32)
  dummy_data = jax.device_put(
      dummy_data,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Start pipeline.
  carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
  _, _, output_buffer, _, _ = carry

  return output_buffer


def main(_):

  # Expect constant number of devices.
  assert NUM_DEVICES == jax.local_device_count()

  # Create mesh.
  mesh = sharding.Mesh(
      mesh_utils.create_device_mesh([NUM_DEVICES]),
      axis_names=["x"],
  )

  # Init weights.
  weights = 1.0 / CONTRACTING_DIM_SIZE
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
      broadcast_dimensions=(),
  )
  weights = jax.device_put(
      weights,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Init random input and replicate it across all devices.
  random_key = jax.random.key(0)
  input_buffer = jax.random.uniform(
      random_key,
      shape=(
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
  )
  input_buffer = jax.lax.broadcast_in_dim(
      input_buffer,
      shape=(
          NUM_DEVICES,
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
      broadcast_dimensions=[1, 2, 3],
  )
  input_buffer = jax.device_put(
      input_buffer,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Run computation.
  output_buffer = entry_computation(weights, input_buffer, mesh)
  print(f"output_buffer = \n{output_buffer}")

使用 psendprecv#

上面的 JAX 示例会转换为 collective-permute HLO 指令,这些指令通过 GPU 上的 ncclSendncclRecv 实现。对于希望更精细控制集合操作顺序的用户,他们可以直接使用 jax.lax.psendjax.lax.precv。从语法上讲,这两个函数与其 HLO 对应项类似。用户应注意,当单个 psendprecv 中的源-目标对形成循环时,以及当 psend 未与 precv 匹配,反之亦然时,程序将死锁。

如果设备通信模式中需要循环,可以通过确保 (1) 单个 psendprecv 函数的源-目标对不包含循环,并且 (2) 插入一个虚拟数据依赖项以使发送/接收对顺序化来避免死锁。在 psend`/precv 对之间无法调度集合操作,这只能通过 JAX 级别的 jax.lax.optimization_barrier 进行控制。文件 shard_map_test.py 中的测试用例 test_psend_precv_basic_with_no_deadlock_cycle 就是一个这样的例子。

上一节中的流水线并行示例使用了 --xla_gpu_experimental_pipeline_parallelism_opt_level XLA 标志。如果手动进行流水线化,同一程序可以使用 psendprecv 重写而无需该标志。

## same setup and imports
def while_body(carry, i):
  (
      weights,
      input_buffer,
      output_buffer,
      prev_compute_res,
      prev_stage_slice_fwd,
      prev_stage_slice_bwd,
  ) = carry

  # Read input data from input buffer.
  input_slice = jax.lax.dynamic_slice(
      input_buffer,
      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # send_fwd
  fwd_send_token = jax.lax.psend(
      prev_compute_res,
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  # Select compute argument based on device and pipeline cycle
  compute_argument = select_on_first_device(
      select_on_first_cycle(i, input_slice, prev_stage_slice_bwd),
      prev_stage_slice_fwd,
  ).reshape((1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

  tmp = compute_argument
  for _ in range(COMPUTE_INTENSITY):
    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
  compute_result = tmp.reshape(
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  )

  buffer_slice_for_bwd_ppermute = jax.lax.dynamic_slice(
      output_buffer,
      (0, (i + 1) % NUM_MICROBATCHES, 0, 0),
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # make sure ppermute is scheduled after send_fwd
  buffer_slice_for_bwd_ppermute_after_send_fwd, _ = (
      jax.lax.optimization_barrier(
          (buffer_slice_for_bwd_ppermute, fwd_send_token)
      )
  )
  # ppermute_bwd
  ppermute_bwd_data = jax.lax.ppermute(
      buffer_slice_for_bwd_ppermute_after_send_fwd,
      axis_name="x",
      perm=[(3, 0)],
  )

  # make sure recv is scheduled after ppermute
  precv_token, _ = jax.lax.optimization_barrier(
      (jax.lax.create_token(), ppermute_bwd_data)
  )

  # recv_fwd, matches the send_fwd in the next iteration
  fwd_recv_data = jax.lax.precv(
      precv_token,
      out_shape=jax.ShapeDtypeStruct(
          input_slice.shape, input_slice.dtype
      ),
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )
  update_output_buffer = jax.lax.dynamic_update_slice(
      output_buffer,
      compute_result,
      (0, (i + 2) % NUM_MICROBATCHES, 0, 0),
  )
  carry = (
      weights,
      input_buffer,
      update_output_buffer,
      compute_result,
      fwd_recv_data,
      ppermute_bwd_data,
  )
  return carry, i


def entry_computation(
    weights, input_buffer, dummy_data, mesh
):

  # Init output buffer.
  output_buffer = jnp.zeros_like(input_buffer)

  # Start pipeline.
  dummy_slice_fwd = jax.lax.precv(
      jax.lax.create_token(),
      jax.ShapeDtypeStruct(dummy_data.shape, dummy_data.dtype),
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  carry = (
      weights,
      input_buffer,
      output_buffer,
      dummy_slice_fwd,
      dummy_data,
      dummy_data,
  )

  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))

  _ = jax.lax.psend(
      carry[3],
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  _, _, output_buffer, _, _, _ = carry

  return output_buffer


def main(_):

  # Expect constant number of devices.
  assert NUM_DEVICES == jax.local_device_count()

  # Create mesh.
  mesh = Mesh(
      mesh_utils.create_device_mesh([NUM_DEVICES]),
      axis_names=["x"],
  )
  # Init weights.
  weights = 1.0 / CONTRACTING_DIM_SIZE
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
      broadcast_dimensions=(),
  )
  weights = jax.device_put(
      weights, NamedSharding(mesh, P("x"))
  )
  # Init input.
  random_key = jax.random.key(0)
  input_buffer = jax.random.uniform(
      random_key,
      shape=(
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
  )
  input_buffer = jax.lax.broadcast_in_dim(
      input_buffer,
      shape=(
          NUM_DEVICES,
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
      broadcast_dimensions=[1, 2, 3],
  )

  input_buffer = jax.device_put(
      input_buffer,
      NamedSharding(mesh, P("x")),
  )
  # Init dummy data for forward and backward edge passed through the while
  # loop.
  dummy_slice = jnp.zeros(
      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  ).astype(jnp.float32)
  dummy_data = jax.device_put(
      dummy_slice,
      NamedSharding(mesh, P("x")),
  )

  entry = partial(entry_computation, mesh=mesh)

  output_buffer = jax.jit(
      jax.shard_map(
          entry,
          mesh=mesh,
          in_specs=P("x"),
          out_specs=P("x"),
          check_vma=False,
      )
  )(weights, input_buffer, dummy_data)
  print(f"output_buffer = \n{output_buffer}")

NCCL 标志#

这些 Nvidia NCCL 标志值可能对 Nvidia GPU 上的单主机多设备计算有用

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

这些 NCCL 标志可以提高单主机通信速度。这些标志似乎对多主机通信尚不适用。

多进程#

我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 JIT 编译的计算。jax.distributed.initialize() API 在 SLURM 下运行时会自动理解该配置。然而,这只是一般经验法则,在您的用例中测试每个 GPU 一个进程和每个节点一个进程可能都有用。