伪随机数生成#

Pallas TPU 实现了多种 API,用于在内核(kernel)内部生成伪随机数,在可移植性和效率之间有不同的权衡。为了获得最大的可移植性,建议直接使用 jax.random 函数。Pallas 还公开了 TPU 上内置的硬件 PRNG,这是计算速度最快的方法,但其底层实现可能因硬件代际而异。

使用 jax.random API#

Pallas 支持 jax.random API 中的一部分操作。当给定相同的密钥(key)时,这些函数保证能产生与在 Pallas 之外的 JAX 中调用它们时按位(bitwise)一致的结果。仅支持 threefry2x32 密钥。

目前支持以下随机采样函数

支持以下工具函数

PRNG 密钥可以使用 jax.random.key() 在内核内部生成。然而,更常见的情况是密钥从调用者传递到内核中。在这种情况下,可以通过 VMEM 将密钥传递给内核,如下所示:

def body(key_ref, o_ref):
  key = key_ref[...]
  o_ref[...] = jax_random.uniform(
      key, shape=o_ref[...].shape, minval=0.0, maxval=1.0
  )

threefry_key = jax_random.key(0, impl="threefry2x32")

# We generate a threefry key outside of the kernel and pass it in via VMEM.
result = pl.pallas_call(
    body,
    in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)],
    out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32)
)(threefry_key)

注意

关于性能方面,在内核内部生成随机数有助于减少内存带宽的使用,因为传递密钥比传递大型随机数数组开销更小。但是,threefry2x32 是一种向量密集型算法,涉及数十次链式位运算。这可能会成为瓶颈并导致加速器利用率降低,因为它没有利用绝大多数 FLOP/s 所在的矩阵乘法单元 (MXU)。

使用硬件 PRNG#

TPU 在硬件中原生实现了一种序列式(而非基于计数器的)PRNG,其计算速度远快于使用软件实现的 PRNG(如 threefry2x32)。然而,JAX 的随机 API 假设使用的是无状态、基于计数器的 PRNG,因此 Pallas 引入了自己的有状态 PRNG API 来提供等效功能。

警告

硬件 PRNG 的底层实现因 TPU 代际而异,因此最佳做法是不依赖其具体行为。对于软件实现的更稳定的 PRNG,建议使用 threefry2x32 实现。

有状态随机数生成#

在有状态模式下使用 Pallas PRNG 是生成随机数最原生且高效的方法。首先,应使用 pltpu.prng_seed(N) 设置 PRNG 种子,其中 N 为整数种子。

之后,您可以调用任意数量的有状态采样函数,它们与对应的 JAX 版本等效,但不需要 key 参数:

生成任何随机数都会更新 PRNG 的内部状态,后续调用将生成不同的数字。与 JAX 不同,无需对密钥进行 split(拆分)或 fold_in(折叠)操作,也无需将其传递给采样函数。

例如,以下内核生成一组从 0 到 1 的均匀分布数字:

from jax.experimental.pallas import tpu as pltpu

def kernel_body(o_ref):
  pltpu.prng_seed(0)
  o_ref[...] = pltpu.stateful_uniform(shape=o_ref.shape, minval=0.0, maxval=1.0)

pl.pallas_call(kernel_body,
               out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32))

请注意,在带有网格(grid)的内核中,种子应仅在第一次迭代时设置,否则由于种子重置,每个程序实例生成的随机数将完全相同。

无状态生成#

Pallas 提供了一种介于前述“无状态 API”和 jax.random API 之间的中间 API,允许您以无状态方式使用硬件 PRNG。为此,请通过 pltpu.to_pallas_key(key) 将 JAX 密钥转换为特殊的 Pallas 类型密钥,并通过 SMEM 将其传递给内核。一旦密钥在内核内解引用,它就可以传递给 jax.random 中支持的采样函数以生成随机数。与无状态 API 相比,每次调用随机数生成器时,计算和设置种子都会产生额外开销。

例如,以下内核使用硬件 PRNG 抽取均匀分布数字:

def body(key_ref, o_ref):
  o_ref[...] = jax.random.uniform(
      key_ref[...], shape=o_ref[...].shape
  )

rbg_key = jax_random.key(0, impl="threefry2x32")
key = pltpu.to_pallas_key(rbg_key)
o_shape = jax.ShapeDtypeStruct((8, 128), dtype)
result = pl.pallas_call(
    body,
    in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
    out_shape=o_shape,
)(key)

对于具有网格的更大内核,可以在 program_id 上使用 jax.random.fold_in(),为每个程序实例生成唯一密钥。

块不变采样#

块不变采样是一种以块形式生成随机数的方法,该方法对所使用的块大小和迭代顺序具有不变性。例如,您可能希望在两个内核(如前向传播和反向传播)之间生成完全相同的随机数集,但这两个内核可能在调优后选择了不同的块大小。

Pallas 提供了一个辅助函数(pltpu.sample_block),允许用户保证在不同的块和网格设置下能够抽取到相同的随机数。第一步是选择一个 tile_size,该瓦片大小需能整除您希望保持不变的所有块大小。例如,tile_size=(16, 128) 适用于块大小为 (32, 128)(16, 256) 的情况。瓦片越大,采样过程越高效,因此所有潜在块大小的最大公约数是最佳选择。

接下来,使用以下参数调用 pltpu.sample_block

pltpu.sample_block(
  sampler_function,  # A JAX random function, such as `jax.random.uniform`.
  global_key,  # A global key shared across all blocks.
  block_size,  # The local block size to generate.
  tile_size,  # The tile size.
  total_size,  # The total shape of the generated array across all blocks.
  block_index,  # The block index into total_size. Usually this is the current program instance.
  **sampler_kwargs  # Keyword arguments to sampler_function
)

例如,以下代码片段在 (16, 128) 块形状和 (32, 256) 块形状(且带有转置网格迭代顺序)下生成相同的数字:

def make_kernel_body(index_map):
  def body(key_ref, o_ref):
    key = key_ref[...]
    samples = pltpu.sample_block(
        jax.random.uniform,
        key,
        block_size=o_ref[...].shape,
        tile_size=(16, 128),
        total_size=(64, 512),
        block_index=index_map(pl.program_id(0), pl.program_id(1)),
        minval=0.0,
        maxval=1.0)
    o_ref[...] = samples
  return body

global_key = pltpu.to_pallas_key(jax_random.key(0))
o_shape = jnp.ones((64, 512), dtype=jnp.float32)
key_spec = pl.BlockSpec(memory_space=pltpu.SMEM)
out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j))
result_16x128 = pl.pallas_call(
    make_kernel_body(index_map=lambda i, j: (i, j)),
    out_shape=o_shape,
    in_specs=[key_spec],
    out_specs=out_spec,
    grid=(4, 4),
)(global_key)

out_spec = pl.BlockSpec((32, 256), lambda i, j: (j, i))
result_32x256_transposed = pl.pallas_call(
    make_kernel_body(index_map=lambda i, j: (j, i)),
    in_specs=[key_spec],
    out_shape=o_shape,
    out_specs=out_spec,
    grid=(2, 2),
)(global_key)