GPU 性能提示#

本文档重点介绍神经网络工作负载的性能提示

矩阵乘法精度#

在最新的 GPU 上,例如 Nvidia A100 或更新的 GPU,使用 bfloat16 精度执行大部分计算可能是一个好主意。例如,如果使用 Flax,请使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例

XLA 性能标志#

注意

JAX-Toolbox 还有一个关于 NVIDIA XLA 性能标志的页面。

XLA 标志的存在和确切行为可能取决于 jaxlib 版本。

截至 jaxlib==0.4.18(于 2023 年 10 月 6 日发布),设置这些 XLA 标志可以提高性能。其中一些与 GPU 之间的通信有关,因此仅在多设备上运行计算时才相关,而另一些则与每个设备的代码生成有关。

其中一些可能在未来的版本中默认设置。

可以通过 XLA_FLAGS shell 环境变量设置这些标志。例如,我们可以在 Python 文件顶部添加以下内容

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)

有关更多示例,另请参阅 Nvidia GPU 上 Pax 训练推荐的 XLA 标志

代码生成标志#

  • –xla_gpu_triton_gemm_any 对其支持的任何 GEMM(矩阵乘法)使用基于 Triton 的 GEMM 发射器。默认值为 False。

通信提示#

自动和手动 PGLE#

Profile Guided Latency Estimator (PGLE) 工作流测量计算和集合通信的实际运行时间,然后将配置文件信息反馈给 XLA 编译器以做出更好的调度决策。

Profile Guided Latency Estimator 可以手动或自动使用。在自动模式下,JAX 将在单次运行中收集配置文件信息并重新编译模块。而在手动模式下,您需要运行任务两次,第一次收集并保存配置文件,第二次使用提供的数据进行编译和运行。

重要提示:JAX 剖析器(下方文档中两种 PGLE 工作流都使用的)不能与 NVIDIA Nsight Systems 剖析器共存。可以通过使用 JAX 编译缓存来避免此限制,如下所述。

自动 PGLE#

可以通过设置以下环境变量来开启自动 PGLE

必填

export JAX_ENABLE_PGLE=true

# For JAX version <= 0.5.0 make sure to include:
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

可选

export JAX_PGLE_PROFILING_RUNS=3
export JAX_PGLE_AGGREGATION_PERCENTILE=85

# Right now the auto PGLE profile collection doesn't work with command buffer.
# If the command buffer is enabled, Auto PGLE will disable it during profile
# collection and enable it back after the recompilation. If you need to have a
# consistent command buffer logic with and with PGLE profile you can disable it
# manually:
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"

或者在 JAX 中可以按以下方式设置

import jax
from jax._src import config

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  # Run with the profiler collecting performance information.
  train_step()
  # Automatically re-compile with PGLE profile results
  train_step()
  ...

您可以通过更改 JAX_PGLE_PROFILING_RUNS 来控制用于收集配置文件数据的重新运行次数。增加此参数将提供更好的配置文件信息,但也会增加非优化训练步骤的数量。

在性能在一个步骤之间过于嘈杂而无法过滤掉不相关测量值的情况下,减小 JAX_PGLE_AGGREGATION_PERCENTILE 参数可能有所帮助。

注意: 自动 PGLE 不适用于预编译模块。由于 JAX 需要在执行期间重新编译模块,因此自动 PGLE 不适用于 AoT 也不适用于以下情况

import jax
from jax._src import config

train_step_compiled = train_step().lower().compile()

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  train_step_compiled()
  # No effect since module was pre-compiled.
  train_step_compiled()

使用 AutoPGLE 时收集 NVIDIA Nsight Systems 配置文件#

jax#24910(JAX v0.5.1 及更高版本)添加了一个新的 JAX 配置选项 JAX_COMPILATION_CACHE_EXPECT_PGLE,它告诉 JAX 尝试从持久编译缓存中加载 PGLE 优化的编译函数。

这允许一个两步过程,第一步将 PGLE 优化的函数写入缓存

export JAX_ENABLE_COMPILATION_CACHE=yes          # not strictly needed, on by default
export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
JAX_ENABLE_PGLE=yes python my-model.py

第二步使用 Nsight Systems 并从缓存加载 PGLE 优化的函数

JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py

有关持久编译缓存及其潜在陷阱的更多信息,请参阅 此页面

手动 PGLE#

如果您仍然想使用手动 Profile Guided Latency Estimator,XLA/GPU 中的工作流是

    1. 运行一次工作负载,启用异步集合通信和延迟隐藏调度器。

您可以通过设置来实现

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
    1. 使用 JAX 剖析器收集和后处理配置文件,将提取的指令延迟保存到二进制 protobuf 文件中。

import os
from etils import epath
import jax
from jax.experimental import profiler as exp_profiler

# Define your profile directory
profile_dir = 'gs://my_bucket/profile'
jax.profiler.start_trace(profile_dir)

# run your workflow
# for i in range(10):
#   train_step()

# Stop trace
jax.profiler.stop_trace()
profile_dir = epath.Path(profile_dir)
directories = profile_dir.glob('plugins/profile/*/')
directories = [d for d in directories if d.is_dir()]
rundir = directories[-1]
logging.info('rundir: %s', rundir)

# Post process the profile
fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

# Save the profile proto to a file.
dump_dir = rundir / 'profile.pb'
dump_dir.parent.mkdir(parents=True, exist_ok=True)
dump_dir.write_bytes(fdo_profile)

在此步骤之后,您将在代码中打印的 rundir 下获得一个 profile.pb 文件。

    1. 再次运行工作负载,将该文件输入编译。

您需要将 profile.pb 文件传递给 --xla_gpu_pgle_profile_file_or_directory_path 标志。

 export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"

要在 XLA 中启用日志记录并检查配置文件是否良好,请将日志级别设置为包含 INFO

export TF_CPP_MIN_LOG_LEVEL=0

运行实际工作负载,如果您在运行日志中发现这些日志记录,则表示剖析器已用于延迟隐藏调度器

2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator

标志#

  • –xla_gpu_enable_latency_hiding_scheduler 此标志启用延迟隐藏调度器,以有效地将异步通信与计算重叠。默认值为 False。

  • –xla_gpu_memory_limit_slop_factor 此标志用作总可用内存的乘数,创建一个阈值,用于指导延迟隐藏调度器 (LHS) 在内存缩减和延迟隐藏优化之间进行平衡。默认值为 95。

    该因子有效地为编译器传递设置了内存限制,决定了调度器何时应优先考虑

    1. 内存缩减:当内存使用量接近或超过计算阈值时。

    2. 延迟隐藏:当内存使用量低于阈值时,允许更积极的优化,这些优化可能会暂时增加内存使用量,但会提高整体性能。

    通过调整此因子,用户可以微调内存效率和性能优化之间的权衡。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于调整何时将多个小的 AllGather/ReduceScatter/AllReduce 合并为一个大的 AllGather/ReduceScatter/AllReduce,以减少跨设备通信花费的时间。例如,对于基于 Transformer 的工作负载的 AllGather/ReduceScatter 阈值,请考虑将其调高,以便至少合并一个 Transformer 层的权重 AllGather/ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。

GPU 上的流水线并行#

使用 XLA 标志#

XLA 实现基于 SPMD 的流水线并行优化。这是一种扩展技术,其中前向和后向传递被分成多个流水线阶段。每个设备(或设备组)处理前一个流水线阶段的结果(或流水线输入),并将其部分结果发送到下一个阶段,直到到达流水线末尾。当计算延迟大于通信延迟时,此优化效果最佳。在编译时,操作将被重新排列以使通信与计算重叠。

为了获得优化的计划,我们推荐这些 XLA 标志

--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''
--xla_disable_hlo_passes=collective-permute-motion
--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE

以下 JAX 示例演示了一种将通信操作安排为与计算重叠的模式。在此示例中,我们将说明如何使用 4 个 GPU 设置优化的流水线并行计划,这些 GPU 形成一个通信环(设备 0 -> 设备 1 -> 设备 2 -> 设备 3 -> 设备 0)。我们将 0 -> 1 -> 2 -> 3 模式称为前向边,将 3 -> 0 称为后向边。

# Imports and setup
import functools
import jax
from jax import sharding
from jax.experimental import mesh_utils
import jax.numpy as jnp
import jax.random

NUM_DEVICES = 4
NUM_MICROBATCHES = 5
NUM_CIRC_REPEATS = 2
CONTRACTING_DIM_SIZE = 4096
NON_CONTRACTING_DIM_SIZE = 8192
COMPUTE_INTENSITY = 32

# Creates a collective permute for the "forward edge".
# 0->1, 1->2, ... (N-2)->(N-1)
def shift_right(arr):
  padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
  # Use lax.slice to guarantee the gradient is a pad.
  return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)


# Creates a collective permute for the "back edge".
# (N-1)->0
def cycle_back(arr):
  padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
  return jax.lax.slice(
      jnp.pad(arr, padding),
      [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
      (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
  )


def select_on_first_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
  return jnp.where(is_first_device, then_value, else_value)


def select_on_last_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_last_device = (
      jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
  )
  return jnp.where(is_last_device, then_value, else_value)


def select_on_first_cycle(i, then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_cycle = i < NUM_MICROBATCHES
  return jnp.where(is_first_cycle, then_value, else_value)


def while_body(carry, i):
  """Body of the pipeline while loop."""
  weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry

  # Read input data from input buffer.
  input_data = jax.lax.dynamic_slice(
      input_buffer,
      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "forward edge" shifts data to the next stage.
  fwd_edge_data = shift_right(fwd_edge_data)

  # Select compute argument based on device and pipeline cycle.
  compute_argument = select_on_first_device(
      select_on_first_cycle(i, input_data, bwd_edge_data),
      fwd_edge_data,
  ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

  # A few matmuls to simulate compute.
  tmp = compute_argument
  for _ in range(COMPUTE_INTENSITY):
    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
  compute_result = tmp.reshape(
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  )

  # Read data from buffer to pass it to the first device of the pipeline on the
  # "back edge".
  bwd_edge_data = jax.lax.dynamic_slice(
      output_buffer,
      (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "back edge" passes data to the first device.
  bwd_edge_data = cycle_back(bwd_edge_data)

  # Update output buffer. We do this after reading from it to avoid the data
  # dependency.
  output_buffer = jax.lax.dynamic_update_slice(
      output_buffer,
      compute_result,
      (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
  )

  fwd_edge_data = compute_result
  carry = (
      weights,
      input_buffer,
      output_buffer,
      fwd_edge_data,
      bwd_edge_data,
  )
  return carry, i


@functools.partial(jax.jit, static_argnames=["mesh"])
def entry_computation(weights, input_buffer, mesh):

  # Init output buffer.
  output_buffer = jnp.zeros_like(input_buffer)

  # Init dummy data for forward and backward edge passed through the while loop.
  dummy_data = jnp.zeros(
      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  ).astype(jnp.float32)
  dummy_data = jax.device_put(
      dummy_data,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Start pipeline.
  carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
  _, _, output_buffer, _, _ = carry

  return output_buffer


def main(_):

  # Expect constant number of devices.
  assert NUM_DEVICES == jax.local_device_count()

  # Create mesh.
  mesh = sharding.Mesh(
      mesh_utils.create_device_mesh([NUM_DEVICES]),
      axis_names=["x"],
  )

  # Init weights.
  weights = 1.0 / CONTRACTING_DIM_SIZE
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
      broadcast_dimensions=(),
  )
  weights = jax.device_put(
      weights,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Init random input and replicate it across all devices.
  random_key = jax.random.key(0)
  input_buffer = jax.random.uniform(
      random_key,
      shape=(
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
  )
  input_buffer = jax.lax.broadcast_in_dim(
      input_buffer,
      shape=(
          NUM_DEVICES,
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
      broadcast_dimensions=[1, 2, 3],
  )
  input_buffer = jax.device_put(
      input_buffer,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("x")
      ),
  )

  # Run computation.
  output_buffer = entry_computation(weights, input_buffer, mesh)
  print(f"output_buffer = \n{output_buffer}")

使用 psendprecv#

上面的 JAX 示例会降低到 collective-permute HLO 指令,这些指令通过 GPU 上的 ncclSendncclRecv 实现。对于希望更精细地控制集合通信顺序的用户,他们可以直接使用 jax.lax.psendjax.lax.precv。在语法上,这两个函数与其 HLO 对应函数类似。用户应注意,当 *单个* psendprecv 中的源-目标对形成循环时,以及当 psendprecv 不匹配(反之亦然)时,程序将死锁。

如果设备通信模式需要循环,可以通过确保 (1) 任何单个 psendprecv 函数的源-目标对不包含循环,并且 (2) 插入一个虚假数据依赖项来顺序化 send/recv 对来避免死锁。在 psend/precv 对之间不能安排任何集合通信,这只能通过 JAX 级别的 jax.lax.optimization_barrier 来控制。shard_map_test.py 文件中的测试用例 test_psend_precv_basic_with_no_deadlock_cycle 是一个此类示例。

上一节中的流水线并行示例使用了 --xla_gpu_experimental_pipeline_parallelism_opt_level XLA 标志。如果不使用该标志,并且手动进行流水线处理,则相同的程序可以使用 psendprecv 重写。

## same setup and imports
def while_body(carry, i):
  (
      weights,
      input_buffer,
      output_buffer,
      prev_compute_res,
      prev_stage_slice_fwd,
      prev_stage_slice_bwd,
  ) = carry

  # Read input data from input buffer.
  input_slice = jax.lax.dynamic_slice(
      input_buffer,
      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # send_fwd
  fwd_send_token = jax.lax.psend(
      prev_compute_res,
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  # Select compute argument based on device and pipeline cycle
  compute_argument = select_on_first_device(
      select_on_first_cycle(i, input_slice, prev_stage_slice_bwd),
      prev_stage_slice_fwd,
  ).reshape((1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

  tmp = compute_argument
  for _ in range(COMPUTE_INTENSITY):
    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
  compute_result = tmp.reshape(
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  )

  buffer_slice_for_bwd_ppermute = jax.lax.dynamic_slice(
      output_buffer,
      (0, (i + 1) % NUM_MICROBATCHES, 0, 0),
      (1, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # make sure ppermute is scheduled after send_fwd
  buffer_slice_for_bwd_ppermute_after_send_fwd, _ = (
      jax.lax.optimization_barrier(
          (buffer_slice_for_bwd_ppermute, fwd_send_token)
      )
  )
  # ppermute_bwd
  ppermute_bwd_data = jax.lax.ppermute(
      buffer_slice_for_bwd_ppermute_after_send_fwd,
      axis_name="x",
      perm=[(3, 0)],
  )

  # make sure recv is scheduled after ppermute
  precv_token, _ = jax.lax.optimization_barrier(
      (jax.lax.create_token(), ppermute_bwd_data)
  )

  # recv_fwd, matches the send_fwd in the next iteration
  fwd_recv_data = jax.lax.precv(
      precv_token,
      out_shape=jax.ShapeDtypeStruct(
          input_slice.shape, input_slice.dtype
      ),
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )
  update_output_buffer = jax.lax.dynamic_update_slice(
      output_buffer,
      compute_result,
      (0, (i + 2) % NUM_MICROBATCHES, 0, 0),
  )
  carry = (
      weights,
      input_buffer,
      update_output_buffer,
      compute_result,
      fwd_recv_data,
      ppermute_bwd_data,
  )
  return carry, i


def entry_computation(
    weights, input_buffer, dummy_data, mesh
):

  # Init output buffer.
  output_buffer = jnp.zeros_like(input_buffer)

  # Start pipeline.
  dummy_slice_fwd = jax.lax.precv(
      jax.lax.create_token(),
      jax.ShapeDtypeStruct(dummy_data.shape, dummy_data.dtype),
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  carry = (
      weights,
      input_buffer,
      output_buffer,
      dummy_slice_fwd,
      dummy_data,
      dummy_data,
  )

  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))

  _ = jax.lax.psend(
      carry[3],
      axis_name="x",
      perm=[(0, 1), (1, 2), (2, 3)],
  )

  _, _, output_buffer, _, _, _ = carry

  return output_buffer


def main(_):

  # Expect constant number of devices.
  assert NUM_DEVICES == jax.local_device_count()

  # Create mesh.
  mesh = Mesh(
      mesh_utils.create_device_mesh([NUM_DEVICES]),
      axis_names=["x"],
  )
  # Init weights.
  weights = 1.0 / CONTRACTING_DIM_SIZE
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
      broadcast_dimensions=(),
  )
  weights = jax.device_put(
      weights, NamedSharding(mesh, P("x"))
  )
  # Init input.
  random_key = jax.random.key(0)
  input_buffer = jax.random.uniform(
      random_key,
      shape=(
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
  )
  input_buffer = jax.lax.broadcast_in_dim(
      input_buffer,
      shape=(
          NUM_DEVICES,
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
      broadcast_dimensions=[1, 2, 3],
  )

  input_buffer = jax.device_put(
      input_buffer,
      NamedSharding(mesh, P("x")),
  )
  # Init dummy data for forward and backward edge passed through the while
  # loop.
  dummy_slice = jnp.zeros(
      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  ).astype(jnp.float32)
  dummy_data = jax.device_put(
      dummy_slice,
      NamedSharding(mesh, P("x")),
  )

  entry = partial(entry_computation, mesh=mesh)

  output_buffer = jax.jit(
      jax.shard_map(
          entry,
          mesh=mesh,
          in_specs=P("x"),
          out_specs=P("x"),
          check_vma=False,
      )
  )(weights, input_buffer, dummy_data)
  print(f"output_buffer = \n{output_buffer}")

NCCL 标志#

这些 Nvidia NCCL 标志值可能对 Nvidia GPU 上的单主机多设备计算有用

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

这些 NCCL 标志可以提高单主机通信速度。目前这些标志似乎对多主机通信没有用。

多进程#

我们建议使用每个 GPU 一个进程,而不是每个节点一个进程。在某些情况下,这可以加速 JIT 编译的计算。在 SLURM 下运行时,jax.distributed.initialize() API 将自动理解该配置。但这只是一个经验法则,在您的具体用例中测试每个 GPU 一个进程和每个节点一个进程可能都有益处。