伪随机数生成#

Pallas TPU 在内核中实现了几种用于生成伪随机数的 API,在可移植性和效率方面具有不同的权衡。为了最大程度的可移植性,请考虑直接使用 jax.random 函数。Pallas 还公开了 TPU 上包含的硬件 PRNG,其计算速度最快,但底层实现可能因硬件代际而异。

使用 jax.random API#

Pallas 支持 jax.random API 中的一部分操作。当给定相同的密钥时,这些函数保证会产生与在 Pallas 外部调用 JAX 中的这些函数相同的按位相等的结果。只支持 threefry2x32 密钥。

当前支持以下随机采样函数

支持以下实用函数

PRNG 密钥可以通过 jax.random.key() 在内核中生成。但是,更常见的情况是将密钥从调用者传递到内核。在这种情况下,密钥可以通过 VMEM 传递到内核,如下所示:

def body(key_ref, o_ref):
  key = key_ref[...]
  o_ref[...] = jax_random.uniform(
      key, shape=o_ref[...].shape, minval=0.0, maxval=1.0
  )

threefry_key = jax_random.key(0, impl="threefry2x32")

# We generate a threefry key outside of the kernel and pass it in via VMEM.
result = pl.pallas_call(
    body,
    in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)],
    out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32)
)(threefry_key)

注意

就性能而言,在内核中生成随机数有助于减少内存带宽使用量,因为传递密钥比传递大量随机数数组更便宜。但是,threefry2x32 是一种向量密集型算法,涉及数十个链式按位运算。这可能成为瓶颈,并导致加速器利用率低,因为它不利用矩阵乘法单元 (MXU),而 MXU 占 FLOP/s 的绝大部分。

使用硬件 PRNG#

TPU 在硬件中原生实现了一个顺序(而非基于计数器)的 PRNG,其计算速度比使用 threefry2x32 等软件实现的 PRNG 快得多。然而,JAX 随机 API 假定使用无状态的、基于计数器的 PRNG,因此 Pallas 引入了自己的有状态 PRNG API 来提供等效功能。

警告

硬件 PRNG 的底层实现因 TPU 代际而异,因此最好不要依赖其确切行为。对于更稳定的软件实现的 PRNG,建议使用 threefry2x32 实现。

有状态的随机数生成#

使用 Pallas PRNG 的有状态模式是生成随机数最原生、最高效的方法。首先,应使用 pltpu.prng_seed(N) 设置 PRNG 种子,其中 N 是一个整数种子。

之后,您可以调用任意数量的有状态采样函数,这些函数等同于相应的 JAX 版本,但缺少 key 参数。

生成任何随机数都会更新 PRNG 的内部状态,后续调用将生成不同的数字。与 JAX 不同,无需 splitfold_in 密钥并将其传递给采样函数。

例如,以下内核生成一组 0 到 1 之间的均匀随机数。

from jax.experimental.pallas import tpu as pltpu

def kernel_body(o_ref):
  pltpu.prng_seed(0)
  o_ref[...] = pltpu.stateful_uniform(shape=o_ref.shape, minval=0.0, maxval=1.0)

pl.pallas_call(kernel_body,
               out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32))

请注意,在带有网格的内核中,种子应该只在第一次迭代时设置,否则由于重置种子,每个程序实例中生成的随机数将是相同的。

无状态生成#

Pallas 提供了一个介于前面描述的无状态 API 和 jax.random 无状态 API 之间的中间 API,允许您以无状态方式使用硬件 PRNG。为此,请通过 pltpu.to_pallas_key(key) 将 JAX 密钥转换为特殊的 Pallas 类型密钥,并通过 SMEM 将此密钥传递到内核。一旦密钥在内核中被解引用,就可以将其传递给 jax.random 中的支持的采样函数以生成随机数。与无状态 API 相比,每次调用随机数生成器时都会产生计算和设置种子的开销。

例如,以下内核使用硬件 PRNG 绘制均匀随机数。

def body(key_ref, o_ref):
  o_ref[...] = jax.random.uniform(
      key_ref[...], shape=o_ref[...].shape
  )

rbg_key = jax_random.key(0, impl="threefry2x32")
key = pltpu.to_pallas_key(rbg_key)
o_shape = jax.ShapeDtypeStruct((8, 128), dtype)
result = pl.pallas_call(
    body,
    in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
    out_shape=o_shape,
)(key)

对于带有网格的较大内核,可以使用 program_id 上的 jax.random.fold_in() 为每个程序实例生成唯一的密钥。

块不变采样#

块不变采样是一种以块为单位生成随机数的方法,该方法对于使用的块大小和迭代顺序是不变的。例如,您可能希望在两个内核(例如前向和后向传递)之间生成相同的随机数集,但这两个内核可能具有在调优后选择的不同块大小。

Pallas 提供了一个辅助函数(pltpu.sample_block),允许您保证在不同的块和网格设置下绘制相同的随机数。第一步是选择一个 tile_size,这是一个可被所有您希望保持不变的块大小整除的块。例如,tile_size=(16, 128) 将适用于块大小为 (32, 128)(16, 256) 的情况。tile_size 越大,采样过程就越有效,因此所有潜在块大小的最大公约数是最佳选择。

接下来,使用以下参数调用 pltpu.sample_block

pltpu.sample_block(
  sampler_function,  # A JAX random function, such as `jax.random.uniform`.
  global_key,  # A global key shared across all blocks.
  block_size,  # The local block size to generate.
  tile_size,  # The tile size.
  total_size,  # The total shape of the generated array across all blocks.
  block_index,  # The block index into total_size. Usually this is the current program instance.
  **sampler_kwargs  # Keyword arguments to sampler_function
)

例如,以下代码片段在 (16, 128) 的块形状和 (32, 256) 的块形状以及转置的网格迭代顺序下生成相同的随机数。

def make_kernel_body(index_map):
  def body(key_ref, o_ref):
    key = key_ref[...]
    samples = pltpu.sample_block(
        jax.random.uniform,
        key,
        block_size=o_ref[...].shape,
        tile_size=(16, 128),
        total_size=(64, 512),
        block_index=index_map(pl.program_id(0), pl.program_id(1)),
        minval=0.0,
        maxval=1.0)
    o_ref[...] = samples
  return body

global_key = pltpu.to_pallas_key(jax_random.key(0))
o_shape = jnp.ones((64, 512), dtype=jnp.float32)
key_spec = pl.BlockSpec(memory_space=pltpu.SMEM)
out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j))
result_16x128 = pl.pallas_call(
    make_kernel_body(index_map=lambda i, j: (i, j)),
    out_shape=o_shape,
    in_specs=[key_spec],
    out_specs=out_spec,
    grid=(4, 4),
)(global_key)

out_spec = pl.BlockSpec((32, 256), lambda i, j: (j, i))
result_32x256_transposed = pl.pallas_call(
    make_kernel_body(index_map=lambda i, j: (j, i)),
    in_specs=[key_spec],
    out_shape=o_shape,
    out_specs=out_spec,
    grid=(2, 2),
)(global_key)