伪随机数生成#
Pallas TPU 在内核中实现了几种用于生成伪随机数的 API,在可移植性和效率方面具有不同的权衡。为了最大程度的可移植性,请考虑直接使用 jax.random 函数。Pallas 还公开了 TPU 上包含的硬件 PRNG,其计算速度最快,但底层实现可能因硬件代际而异。
使用 jax.random
API#
Pallas 支持 jax.random
API 中的一部分操作。当给定相同的密钥时,这些函数保证会产生与在 Pallas 外部调用 JAX 中的这些函数相同的按位相等的结果。只支持 threefry2x32
密钥。
当前支持以下随机采样函数
支持以下实用函数
PRNG 密钥可以通过 jax.random.key()
在内核中生成。但是,更常见的情况是将密钥从调用者传递到内核。在这种情况下,密钥可以通过 VMEM 传递到内核,如下所示:
def body(key_ref, o_ref):
key = key_ref[...]
o_ref[...] = jax_random.uniform(
key, shape=o_ref[...].shape, minval=0.0, maxval=1.0
)
threefry_key = jax_random.key(0, impl="threefry2x32")
# We generate a threefry key outside of the kernel and pass it in via VMEM.
result = pl.pallas_call(
body,
in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)],
out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32)
)(threefry_key)
注意
就性能而言,在内核中生成随机数有助于减少内存带宽使用量,因为传递密钥比传递大量随机数数组更便宜。但是,threefry2x32
是一种向量密集型算法,涉及数十个链式按位运算。这可能成为瓶颈,并导致加速器利用率低,因为它不利用矩阵乘法单元 (MXU),而 MXU 占 FLOP/s 的绝大部分。
使用硬件 PRNG#
TPU 在硬件中原生实现了一个顺序(而非基于计数器)的 PRNG,其计算速度比使用 threefry2x32
等软件实现的 PRNG 快得多。然而,JAX 随机 API 假定使用无状态的、基于计数器的 PRNG,因此 Pallas 引入了自己的有状态 PRNG API 来提供等效功能。
警告
硬件 PRNG 的底层实现因 TPU 代际而异,因此最好不要依赖其确切行为。对于更稳定的软件实现的 PRNG,建议使用 threefry2x32
实现。
有状态的随机数生成#
使用 Pallas PRNG 的有状态模式是生成随机数最原生、最高效的方法。首先,应使用 pltpu.prng_seed(N)
设置 PRNG 种子,其中 N 是一个整数种子。
之后,您可以调用任意数量的有状态采样函数,这些函数等同于相应的 JAX 版本,但缺少 key
参数。
pltpu.stateful_uniform
:有状态的jax.random.uniform()
等效函数。pltpu.stateful_normal
:有状态的jax.random.normal()
等效函数。pltpu.stateful_bernoulli
:有状态的jax.random.bernoulli()
等效函数。
生成任何随机数都会更新 PRNG 的内部状态,后续调用将生成不同的数字。与 JAX 不同,无需 split
或 fold_in
密钥并将其传递给采样函数。
例如,以下内核生成一组 0 到 1 之间的均匀随机数。
from jax.experimental.pallas import tpu as pltpu
def kernel_body(o_ref):
pltpu.prng_seed(0)
o_ref[...] = pltpu.stateful_uniform(shape=o_ref.shape, minval=0.0, maxval=1.0)
pl.pallas_call(kernel_body,
out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32))
请注意,在带有网格的内核中,种子应该只在第一次迭代时设置,否则由于重置种子,每个程序实例中生成的随机数将是相同的。
无状态生成#
Pallas 提供了一个介于前面描述的无状态 API 和 jax.random
无状态 API 之间的中间 API,允许您以无状态方式使用硬件 PRNG。为此,请通过 pltpu.to_pallas_key(key)
将 JAX 密钥转换为特殊的 Pallas 类型密钥,并通过 SMEM 将此密钥传递到内核。一旦密钥在内核中被解引用,就可以将其传递给 jax.random
中的支持的采样函数以生成随机数。与无状态 API 相比,每次调用随机数生成器时都会产生计算和设置种子的开销。
例如,以下内核使用硬件 PRNG 绘制均匀随机数。
def body(key_ref, o_ref):
o_ref[...] = jax.random.uniform(
key_ref[...], shape=o_ref[...].shape
)
rbg_key = jax_random.key(0, impl="threefry2x32")
key = pltpu.to_pallas_key(rbg_key)
o_shape = jax.ShapeDtypeStruct((8, 128), dtype)
result = pl.pallas_call(
body,
in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
out_shape=o_shape,
)(key)
对于带有网格的较大内核,可以使用 program_id
上的 jax.random.fold_in()
为每个程序实例生成唯一的密钥。
块不变采样#
块不变采样是一种以块为单位生成随机数的方法,该方法对于使用的块大小和迭代顺序是不变的。例如,您可能希望在两个内核(例如前向和后向传递)之间生成相同的随机数集,但这两个内核可能具有在调优后选择的不同块大小。
Pallas 提供了一个辅助函数(pltpu.sample_block
),允许您保证在不同的块和网格设置下绘制相同的随机数。第一步是选择一个 tile_size
,这是一个可被所有您希望保持不变的块大小整除的块。例如,tile_size=(16, 128)
将适用于块大小为 (32, 128)
和 (16, 256)
的情况。tile_size 越大,采样过程就越有效,因此所有潜在块大小的最大公约数是最佳选择。
接下来,使用以下参数调用 pltpu.sample_block
:
pltpu.sample_block(
sampler_function, # A JAX random function, such as `jax.random.uniform`.
global_key, # A global key shared across all blocks.
block_size, # The local block size to generate.
tile_size, # The tile size.
total_size, # The total shape of the generated array across all blocks.
block_index, # The block index into total_size. Usually this is the current program instance.
**sampler_kwargs # Keyword arguments to sampler_function
)
例如,以下代码片段在 (16, 128) 的块形状和 (32, 256) 的块形状以及转置的网格迭代顺序下生成相同的随机数。
def make_kernel_body(index_map):
def body(key_ref, o_ref):
key = key_ref[...]
samples = pltpu.sample_block(
jax.random.uniform,
key,
block_size=o_ref[...].shape,
tile_size=(16, 128),
total_size=(64, 512),
block_index=index_map(pl.program_id(0), pl.program_id(1)),
minval=0.0,
maxval=1.0)
o_ref[...] = samples
return body
global_key = pltpu.to_pallas_key(jax_random.key(0))
o_shape = jnp.ones((64, 512), dtype=jnp.float32)
key_spec = pl.BlockSpec(memory_space=pltpu.SMEM)
out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j))
result_16x128 = pl.pallas_call(
make_kernel_body(index_map=lambda i, j: (i, j)),
out_shape=o_shape,
in_specs=[key_spec],
out_specs=out_spec,
grid=(4, 4),
)(global_key)
out_spec = pl.BlockSpec((32, 256), lambda i, j: (j, i))
result_32x256_transposed = pl.pallas_call(
make_kernel_body(index_map=lambda i, j: (j, i)),
in_specs=[key_spec],
out_shape=o_shape,
out_specs=out_spec,
grid=(2, 2),
)(global_key)