GPU 性能提示#

本文档重点介绍神经网络工作负载的性能提示

Matmul 精度#

在最近的 GPU 世代中,例如 Nvidia A100 世代或更高版本,以 bfloat16 精度执行大多数计算可能是一个好主意。例如,如果使用 Flax,则使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例

XLA 性能标志#

注意

JAX-Toolbox 还有一个关于 NVIDIA XLA 性能标志的页面。

XLA 标志的存在和确切行为可能取决于 jaxlib 版本。

截至 jaxlib==0.4.18(发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提高性能。其中一些与 GPU 之间的通信有关,因此仅在多设备上运行计算时才相关,而另一些则与每个设备上的代码生成有关。

其中一些可能在未来的版本中默认设置。

这些标志可以通过 XLA_FLAGS shell 环境变量设置。例如,我们可以将此添加到 Python 文件的顶部

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)

有关更多示例,另请参阅 为 Nvidia GPU 上的 Pax 训练推荐的 XLA 标志

代码生成标志#

  • –xla_gpu_triton_gemm_any 对其支持的任何 GEMM 使用基于 Triton 的 GEMM(matmul)发射器。默认值为 False。

通信技巧#

自动和手动 PGLE#

配置文件引导的延迟估计器 (PGLE) 工作流程测量计算和集体操作的实际运行时间,并将配置文件信息反馈到 XLA 编译器,以获得更好的调度决策。

配置文件引导的延迟估计器可以手动或自动使用。在自动模式下,JAX 将收集配置文件信息并在单次运行中重新编译模块。而在手动模式下,您需要运行任务两次,第一次收集和保存配置文件,第二次使用提供的数据编译和运行。

重要提示:JAX 性能分析器(以下记录的 PGLE 工作流程都使用它)不能与 NVIDIA Nsight Systems 性能分析器共存。通过使用 JAX 编译缓存可以避免此限制,如下所述。

自动 PGLE#

可以通过设置以下环境变量来启用自动 PGLE

必需

export JAX_ENABLE_PGLE=true

# For JAX version <= 0.5.0 make sure to include:
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

可选

export JAX_PGLE_PROFILING_RUNS=3
export JAX_PGLE_AGGREGATION_PERCENTILE=85

# Right now the auto PGLE profile collection doesn't work with command buffer.
# If the command buffer is enabled, Auto PGLE will disable it during profile
# colletion and enable it back after the recompilation. If you need to have a
# consistent command buffer logic with and with PGLE profile you can disable it
# manually:
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"

或者在 JAX 中,可以将其设置为以下内容

import jax
from jax._src import config

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  # Run with the profiler collecting performance information.
  train_step()
  # Automatically re-compile with PGLE profile results
  train_step()
  ...

您可以通过更改 JAX_PGLE_PROFILING_RUNS 来控制用于收集配置文件数据的重新运行次数。增加此参数将导致更好的配置文件信息,但也会增加非优化训练步骤的数量。

如果步骤之间的性能过于嘈杂而无法过滤掉不相关的度量,则降低 JAX_PGLE_AGGREGATION_PERCENTILE 参数可能会有所帮助。

注意: 自动 PGLE 不适用于预编译模块。由于 JAX 需要在执行期间重新编译模块,因此自动 PGLE 既不适用于 AoT,也不适用于以下情况

import jax
from jax._src import config

train_step_compiled = train_step().lower().compile()

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  train_step_compiled()
  # No effect since module was pre-compiled.
  train_step_compiled()

使用 AutoPGLE 时收集 NVIDIA Nsight Systems 配置文件#

jax#24910 (JAX v0.5.1 及更高版本) 添加了一个新的 JAX 配置选项 JAX_COMPILATION_CACHE_EXPECT_PGLE,它告诉 JAX 尝试从持久编译缓存加载 PGLE 优化的编译函数。

这允许一个两步过程,其中第一步将 PGLE 优化的函数写入缓存

export JAX_ENABLE_COMPILATION_CACHE=yes          # not strictly needed, on by default
export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
JAX_ENABLE_PGLE=yes python my-model.py

第二步使用 Nsight Systems 并从缓存加载 PGLE 优化的函数

JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py

另请参阅 此页面,以获取有关持久编译缓存和可能陷阱的更多信息。

手动 PGLE#

如果您仍然想使用手动配置文件引导的延迟估计器,则 XLA/GPU 中的工作流程是

    1. 运行您的工作负载一次,启用异步集体操作和延迟隐藏调度器。

您可以通过设置以下内容来做到这一点

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
    1. 通过使用 JAX 性能分析器收集和后处理配置文件,将提取的指令延迟保存到二进制 protobuf 文件中。

import os
from etils import epath
import jax
from jax.experimental import profiler as exp_profiler

# Define your profile directory
profile_dir = 'gs://my_bucket/profile'
jax.profiler.start_trace(profile_dir)

# run your workflow
# for i in range(10):
#   train_step()

# Stop trace
jax.profiler.stop_trace()
profile_dir = epath.Path(profile_dir)
directories = profile_dir.glob('plugins/profile/*/')
directories = [d for d in directories if d.is_dir()]
rundir = directories[-1]
logging.info('rundir: %s', rundir)

# Post process the profile
fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

# Save the profile proto to a file.
dump_dir = rundir / 'profile.pb'
dump_dir.parent.mkdir(parents=True, exist_ok=True)
dump_dir.write_bytes(fdo_profile)

在此步骤之后,您将在代码中打印的 rundir 下获得一个 profile.pb 文件。

    1. 再次运行工作负载,将该文件馈送到编译中。

您需要将 profile.pb 文件传递给 --xla_gpu_pgle_profile_file_or_directory_path 标志。

 export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"

要在 XLA 中启用日志记录并检查配置文件是否良好,请将日志记录级别设置为包含 INFO

export TF_CPP_MIN_LOG_LEVEL=0

运行实际的工作流程,如果您在运行日志中找到这些日志记录,则意味着性能分析器已在延迟隐藏调度器中使用

2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator

标志#

  • –xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度器,以有效地将异步通信与计算重叠。默认值为 False。

  • –xla_gpu_memory_limit_slop_factor 此标志充当应用于总可用内存的乘数,创建一个阈值,该阈值指导延迟隐藏调度器 (LHS) 在平衡内存减少和延迟隐藏优化方面。默认值为 95。

    此因子有效地为编译器传递建立内存限制,确定调度器何时应优先考虑

    1. 内存减少:当内存使用量接近或超过计算出的阈值时。

    2. 延迟隐藏:当内存使用量低于阈值时,允许更积极的优化,这些优化可能会暂时增加内存使用量,但会提高整体性能。

    通过调整此因子,用户可以微调内存效率和性能优化之间的权衡。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志调整何时将多个小的 AllGather/ReduceScatter/AllReduce 组合成一个大的 AllGather/ReduceScatter/AllReduce,以减少花费在跨设备通信上的时间。例如,对于基于 Transformer 的工作负载上的 AllGather/ReduceScatter 阈值,请考虑将其调整得足够高,以便至少组合 Transformer 层的权重 AllGather/ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

GPU 上的流水线并行#

XLA 实现了基于 SPMD 的流水线并行优化。这是一种扩展技术,其中前向和后向传递被拆分为多个流水线阶段。每个设备(或设备组)处理前一个流水线阶段(或流水线输入)的结果,并将其部分结果发送到下一个阶段,直到到达流水线的末端。当计算延迟大于通信时,此优化效果最佳。在编译时,将重新排列操作以使通信与计算重叠。

为了获得优化的调度,我们推荐以下 XLA 标志

--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''
--xla_disable_hlo_passes=collective-permute-motion
--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE

以下 JAX 示例演示了一种模式,其中通信操作被安排为与计算重叠。在此示例中,我们将说明如何使用 4 个 GPU 设置优化的流水线并行调度,这些 GPU 形成一个通信环(设备 0 -> 设备 1 -> 设备 2 -> 设备 3 -> 设备 0)。我们将模式 0 -> 1 -> 2 -> 3 称为前向边,3 -> 0 称为后向边。

# Imports and setup
import functools
import jax
from jax import sharding
from jax.experimental import mesh_utils
import jax.numpy as jnp
import jax.random

NUM_DEVICES = 4
NUM_MICROBATCHES = 5
NUM_CIRC_REPEATS = 2
CONTRACTING_DIM_SIZE = 4096
NON_CONTRACTING_DIM_SIZE = 8192
COMPUTE_INTENSITY = 32

# Creates a collective permute for the "forward edge".
# 0->1, 1->2, ... (N-2)->(N-1)
def shift_right(arr):
  padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
  # Use lax.slice to guarantee the gradient is a pad.
  return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)


# Creates a collective permute for the "back edge".
# (N-1)->0
def cycle_back(arr):
  padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
  return jax.lax.slice(
      jnp.pad(arr, padding),
      [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
      (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
  )


def select_on_first_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
  return jnp.where(is_first_device, then_value, else_value)


def select_on_last_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_last_device = (
      jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
  )
  return jnp.where(is_last_device, then_value, else_value)


def select_on_first_cycle(i, then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_cycle = i < NUM_MICROBATCHES
  return jnp.where(is_first_cycle, then_value, else_value)


def while_body(carry, i):
  """Body of the pipeline while loop."""
  weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry

  # Read input data from input buffer.
  input_data = jax.lax.dynamic_slice(
      input_buffer,
      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "forward edge" shifts data to the next stage.
  fwd_edge_data = shift_right(fwd_edge_data)

  # Select compute argument based on device and pipeline cycle.
  compute_argument = select_on_first_device(
      select_on_first_cycle(i, input_data, bwd_edge_data),
      fwd_edge_data,
  ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

  # A few matmuls to simulate compute.
  tmp = compute_argument
  for _ in range(COMPUTE_INTENSITY):
    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
  compute_result = tmp.reshape(
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  )

  # Read data from buffer to pass it to the first device of the pipeline on the
  # "back edge".
  bwd_edge_data = jax.lax.dynamic_slice(
      output_buffer,
      (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Colelctive permute on the "back edge" passes data to the first device.
  bwd_edge_data = cycle_back(bwd_edge_data)

  # Update output buffer. We do this after reading from it to avoid the data
  # dependency.
  output_buffer = jax.lax.dynamic_update_slice(
      output_buffer,
      compute_result,
      (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
  )

  fwd_edge_data = compute_result
  carry = (
      weights,
      input_buffer,
      output_buffer,
      fwd_edge_data,
      bwd_edge_data,
  )
  return carry, i


@functools.partial(jax.jit, static_argnames=["mesh"])
def entry_computation(weights, input_buffer, mesh):

  # Init output buffer.
  output_buffer = jnp.zeros_like(input_buffer)

  # Init dummy data for forward and backward edge passed through the while loop.
  dummy_data = jnp.zeros(
      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  ).astype(jnp.float32)
  dummy_data = jax.device_put(
      dummy_data,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("the_one_and_only_axis")
      ),
  )

  # Start pipeline.
  carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
  _, _, output_buffer, _, _ = carry

  return output_buffer


def main(_):

  # Expect constant number of devices.
  assert NUM_DEVICES == jax.local_device_count()

  # Create mesh.
  mesh = sharding.Mesh(
      mesh_utils.create_device_mesh([NUM_DEVICES]),
      axis_names=["the_one_and_only_axis"],
  )

  # Init weights.
  weights = 1.0 / CONTRACTING_DIM_SIZE
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
      broadcast_dimensions=(),
  )
  weights = jax.device_put(
      weights,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("the_one_and_only_axis")
      ),
  )

  # Init random input and replicate it across all devices.
  random_key = jax.random.key(0)
  input_buffer = jax.random.uniform(
      random_key,
      shape=(
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
  )
  input_buffer = jax.lax.broadcast_in_dim(
      input_buffer,
      shape=(
          NUM_DEVICES,
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
      broadcast_dimensions=[1, 2, 3],
  )
  input_buffer = jax.device_put(
      input_buffer,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("the_one_and_only_axis")
      ),
  )

  # Run computation.
  output_buffer = entry_computation(weights, input_buffer, mesh)
  print(f"output_buffer = \n{output_buffer}")

NCCL 标志#

这些 Nvidia NCCL 标志值可能对 Nvidia GPU 上的单主机多设备计算有用

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

这些 NCCL 标志可以提高单主机通信速度。这些标志对于多主机通信似乎尚无用。

多进程#

我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 jitted 计算。jax.distributed.initialize() API 在 SLURM 下运行时会自动理解该配置。但是,这只是一条经验法则,在您的用例中测试每个 GPU 一个进程和每个节点一个进程可能很有用。