标量预取和块稀疏计算#

在本教程中,我们将介绍 Pallas 中块稀疏计算的基础知识。稀疏计算是编写自定义 Pallas 内核而不是简单地使用 JAX/XLA 的一个主要原因,因为通常很难在 XLA 中表达执行动态计算量的程序,这归因于静态数组形状。在本教程中,我们将学习如何使用 Pallas 的标量预取功能来编写块稀疏内核,这些内核可以动态跳过计算和内存块。

import functools
import timeit
import numpy as np
import jax
from jax import numpy as jnp
from jax import lax
from jax.experimental import checkify
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print("Running on", jax.devices()[0].device_kind)
Running on TPU v5 lite

使用标量预取的动态块索引#

我们将利用 Pallas 的“标量预取”功能来使我们能够编写稀疏内核。标量预取允许您将少量数据传递到 SMEM(“标量内存”),这些数据在管道开始之前加载(“预取”)。由于此数据在管道之前加载,因此它可用于每个 BlockSpec 的 index_map 中,从而允许您执行数据相关的索引计算。本教程的主要目标是回顾利用此功能的常见编程模式。

要使用标量预取,请使用 pltpu.PrefetchScalarGridSpec 代替标准 pl.GridSpec

class PrefetchScalarGridSpec:
  def __init__(self,
    num_scalar_prefetch: int,
    grid: tuple[int, ...],
    in_specs: PyTree[BlockSpec],
    out_specs: PyTree[BlockSpec],
    scratch_shapes: tuple[MemorySpace, ...]):
      ...

num_scalar_prefetch 参数指示标量预取值的数量。当设置为非零值时,它会更改内核和索引映射的调用签名,以期望额外的预取值。传递到 index_map 和内核的预取 Ref 都分配在 SMEM 中,并且不分区为块,因为它们没有定义 BlockSpec。此外,index_map 和内核的参数顺序始终是固定的,如下所述

  • 现在,每个 BlockSpecindex_map 期望预取 Ref 在网格索引之后出现

def index_map(*grid_indices, *prefetch_refs):
    ...
  • 用户定义的内核期望预取 Ref 在输入 Ref 之前出现。此外,暂存 refs 在输出 Ref 之后出现。

def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
    ...
  • 当使用 pallas_call 调用新内核时,pallas_call 返回的函数也期望标量预取参数在输入之前出现,例如

kernel = pl.pallas_call(...)
result = kernel(*prefetch_args, *input_args)

示例:使用标量预取的块动态切片#

让我们从一个基本示例开始,该示例演示如何使用标量预取功能。我们将实现一个块对齐的动态切片内核,该内核仅根据用户指定的索引从较大的数组中提取一个块

  1. 在内核外部,我们将要提取的块索引计算为:block_idx = (start[0] // size[0], start[1] // size[1])

  2. 我们将 block_idx 作为标量预取参数传递到 pallas_call 中。

  3. 在我们的索引映射中,我们使用块索引来选择相应的块,方法是返回 (block_idx[0], block_idx[1])

当然,此内核受到限制,因为我们的切片大小必须适合内核块内部(受 VMEM 大小的限制),并且我们只能在大小对齐的索引上启动。更高级的内核会将内核块大小与切片大小解耦,并允许非对齐的起始索引。

def dynamic_slice_kernel(indices, x_ref, o_ref):
  del indices
  o_ref[...] = x_ref[...]

@checkify.checkify
@functools.partial(jax.jit, static_argnums=(2,))
def block_dynamic_slice(x, starts, sizes):
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,
      grid=(1, 1),
      in_specs=[pl.BlockSpec(
          sizes,
          lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
      out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
  )

  kernel = pl.pallas_call(
    dynamic_slice_kernel,
    grid_spec=grid_spec,
    out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
  )
  # Checkify inserts a runtime assert that starts are divisible by block size.
  checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
  checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
  block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
  return kernel(block_idx, x)

shape = (512, 512)
x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
err.throw()
ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
diff = jnp.max(jnp.abs(result - ref))
print("Error |result - lax.dynamic_slice| =", diff)
Error |result - lax.dynamic_slice| = 0

稀疏内核:表示稀疏数据#

在我们深入实现稀疏内核之前,让我们首先回顾一下稀疏矩阵是如何表示的。虽然有几种流行的格式用于存储稀疏矩阵,但我们将遵循坐标列表格式 (COO) 的块变体,其中我们将矩阵存储为 (block_index, block_data) 对的列表。所有未在列表中显式存储的块都假定为零,这意味着如果矩阵中有很多零块,我们可以节省大量内存。

下图演示了我们如何将 4x4 稠密矩阵(左)转换为块 COO 格式(右),块大小为 2x2。请注意,在稀疏格式中,我们可以避免显式存储由所有零元素组成的右上角块。

block_coo

我们将使用以下辅助函数来采样块稀疏矩阵。它返回一个用于检查结果的稠密矩阵,以及每个轴的块数据和索引列表。

def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
  """Returns a sampled matrix and its block-sparse representation.

  Args:
    key: RNG Key.
    M: Major array dimension.
    N: Minor array dimension.
    blk_M: Block size along M dimension.
    blk_N: Block size along N dimension.
    p: Probability that a block will be non-zero.
    dtype: dtype of the sampled matrix.

  Returns:
    dense_mat: A (M, N) dense sampled array.
    block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
      the non-zero blocks of the matrix.
    indices_i: A (num_blocks,) array of block indices for the first axis.
    indices_j: A (num_blocks,) array of block indices for the second axis.
  """
  mask_key, blocks_key = jax.random.split(key)
  num_blocks = (M // blk_M, N // blk_N)
  # We first sample a block mask, denoting which blocks are nonzero.
  block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
  num_blocks = jnp.sum(block_mask)
  indices = jnp.where(block_mask)
  # For each non-zero block, we sample a block of random values.
  block_data = jax.random.uniform(blocks_key,
                                  shape=(num_blocks, blk_M, blk_N),
                                  dtype=dtype)
  # For checking purposes, create the dense version of the sparse matrix.
  dense_mat = jnp.zeros((M, N), dtype=dtype)
  for blk in range(num_blocks):
    idx_i = indices[0][blk]
    idx_j = indices[1][blk]
    slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
    slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
    dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
  return dense_mat, block_data, indices[0], indices[1]

示例:稀疏 @ 稠密矩阵乘法#

在我们的第一个示例中,我们将把稀疏 LHS 矩阵与稠密 RHS 矩阵相乘,以产生稠密输出。

我们将使用 2 个循环来构建内核网格 - 外循环遍历 RHS/输出的列,内循环遍历 LHS 的稀疏块。在每个内循环迭代期间,我们从 LHS 加载一个块,并使用收缩维度 (K) 的块索引在 RHS 中查找相应的块。我们将两个块相乘并累加到正确的输出块中。一个外循环迭代将计算整个列的结果,如下图所示

sparse_matmul

重要的是,我们在将块索引传递到内核之前,按行对块索引进行分组(例如 [0, 0, 1, 2, 3, 3])。首先,在我们的内核中,我们需要知道何时最初将累加器归零到输出 ref 中,如果行索引已分组,则很容易做到这一点。其次,Pallas 的流水线逻辑不允许我们在非连续迭代中重新访问输出 Ref 中的块,因此我们需要在连续的内核迭代中完成对输出块的所有累加。这是因为流水线发射器将意识到我们在连续迭代中加载相同的输出块,并将该块保留在 VMEM 中。当我们更改输出块时,Pallas 最终会将输出存储到 HBM 中,并假设我们永远不会再触碰它。即使内核在逻辑上是正确的,未能连续访问输出块也会导致值不正确。

M = N = K = 16384
blk_M = blk_N = blk_K = 512


def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
               x_ref, y_ref, _, o_ref, # Kernel inputs.
               accum_scratch,
               ):
  """A DSD (Dense = Sparse @ Dense) matmul kernel."""
  del idxs_k_ref
  blk_idx = pl.program_id(1)
  is_start = blk_idx == 0
  changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
  @pl.when(is_start | changed_blocks)
  def _():
    accum_scratch[...] = jnp.zeros_like(accum_scratch)
  accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)

  next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
  is_end = blk_idx == (num_blocks - 1)
  @pl.when(is_end | next_block_change)
  def _():
    o_ref[...] = accum_scratch[...].astype(o_ref.dtype)


def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del j, blk_idxs_i, blk_idxs_k
  return (blk_idx, 0, 0)
def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del blk_idxs_i
  return (blk_idxs_k[blk_idx], j)
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del blk_idxs_k
  return (blk_idxs_i[blk_idx], j)

(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
    jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
num_blocks = X_blocks.shape[0]
Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=2,
    # Note that while num_blocks is static here, Pallas does support
    # dynamic grid sizes.
    grid=(N // blk_N, num_blocks),
    in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              # Placeholder for a zeros-array used by input_output_aliases.
              pl.BlockSpec((blk_M, blk_N), o_map),
              ],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
  dsd_kernel,
  grid_spec=grid_spec,
  out_shape=out_shape,
  # We use input-output aliases to zero-out o_ref for blocks that we never
  # visit. By passing in an array of zeros we avoid having o_ref start with
  # uninitialized values.
  input_output_aliases={4: 0},  # Map zeros to o_ref.
)
args = (indices_i, indices_k, X_blocks, Y, zeros)
result = kernel(*args)

ref = X_dense @ Y
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 0

我们可以做一个快速基准测试,以比较我们的稀疏内核与 JAX 中稠密 matmul 的性能。在 TPU v5e 芯片上,与稀疏因子理论上的 10 倍相比,此内核实现了大约 6 倍的速度提升。

这里有一些主要的性能提示,主要围绕减少 HBM/VMEM 之间的通信开销

  • 使用 dtype=jnp.bfloat16 对于性能至关重要,因为它将内存带宽减少了一半。

  • 使用更大的块大小也有帮助,因为矩阵乘法是 \(O(N^3)\) 计算和 \(O(N^2)\) 内存操作。随着 \(N\) 变大,内核变得受计算限制。但是,在实践中,对此的反驳是较小的块大小也使数据更稀疏,因此这是一个应仔细选择的参数。

# Benchmark Sparse Pallas kernel vs reference JAX implementation

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    return time
  return run


n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda x, y: x @ y)
time = benchmark(ref_impl, n_trials)(X_dense, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 8.136 ms (avg over 100 trials)
Reference: 46.953 ms (avg over 100 trials)

稠密数据上的稀疏访问模式#

在我们之前的示例中,我们考虑了数据本身是稀疏的情况。这在内核结构中表现为一个内核网格中的维度是动态的,并循环遍历非零块的数量 (num_blocks)。

当底层数据是稠密的,但我们希望对其执行稀疏计算时,会出现第二种有用的编程模式。在这种情况下,我们的内核网格将是稠密的,但我们希望跳过网格中的某些块,如块稀疏掩码所示。当在许多机器学习应用中使用掩码时,例如自注意力中的因果掩码或局部掩码,通常会出现这种编程模式。在这些情况下,我们可以完全跳过掩码归零的块中的计算。jax/experimental/pallas/ops/tpu 中的 Splash Attention 和 Grouped Matrix Multiplication 内核或 PyTorch 的 FlexAttention 中可以找到这种编程模式的示例。

在稠密数据上处理稀疏访问模式的主要性能考虑因素是与流水线的交互。在任何给定的内核迭代中,Pallas 流水线发射器将尝试通过为网格的下一次迭代上的每个 BlockSpec 调用 index_map 来预取下一个数据块。但是,如果我们的计算是稀疏的,我们可能会跳过网格中下一个块的计算,因此我们需要某种方法来告诉流水线开始获取 *我们不跳过的下一个块*。为了做到这一点,我们需要构建 *预取映射*,其中包含每个内核输入的下一个非跳过数据块的索引。下图说明了如何为以类似 COO 格式存储的块稀疏掩码构建预取映射。

prefetch_map

左图:稀疏访问模式,其中蓝色表示我们需要计算的具有非零掩码的块。右图:预取映射,其中数组的每个元素都包含下一个非零块数据的索引。

构造预取映射后,我们可以将该映射作为标量预取参数传递,并在 BlockSpec 的 index_map 函数中查询它。

def mask_index_map(prefetch_map, i, j, ...):
  next_nonzero_block = prefetch_map[i, j]
  return (next_nonzero_block, 0, 0)

我们可以为内核的其他输入构造类似的索引映射。对于稠密输入,您很可能需要构造预取映射,该映射指向网格中下一个非零块索引。我们的下一个示例将提供使用这些预取映射的示例。

示例:使用块稀疏输出掩码的稠密 @ 稠密矩阵乘法#

在我们的下一个示例中,我们将介绍稠密矩阵乘法与稀疏输出掩码的融合,使用预取映射来提高流水线性能。我们将使用掩码来选择性地跳过计算归零的输出块,从而节省计算成本。

由于我们将使用稀疏掩码,因此我们将首先实现一个函数,该函数将以稠密格式存储的 N x M 掩码转换为块稀疏格式。此外,我们需要计算预取映射,以帮助流水线发射器知道接下来要获取哪个块。总而言之,我们的 sparsify_mask 函数计算

  • 形状为 (num_N_blocks, num_M_blocks)block_mask,指示块是否全为零(值 0)或包含非零元素(值 1)。如果 block_mask 的值为 0,我们可以跳过内核中块的计算。

  • 形状为 (num_N_blocks, num_M_blocks)prefetch_mask 数组,由 mask_data 中下一个非零块的索引组成。

  • 形状为 (num_N_blocks, num_M_blocks)prefetch_i 数组,由掩码的下一个非掩码 i 索引组成。

  • 形状为 (num_N_blocks, num_M_blocks)prefetch_j 数组,由掩码的下一个非掩码 j 索引组成。

  • 形状为 (num_blocks, blk_N, blk_M)mask_data 数组,包含掩码的非零块的数据。

def sparsify_mask(mask: jax.Array,
                  block_shape: tuple[int, int]):
  """Preprocesses a mask into a sparse reprentation.

  Args:
    mask: A boolean array of shape [M, N]
    block_shape: The size of a single block.

  Returns:
    block_mask: A block_shape array of booleans indicating whether a block
      is all-zeros (0) or contains non-zero elements (1).
    prefetch_mask: A block_shape array of integers indicating the index of the
      next non-zero block.
    mask_data: A (num_blocks, block_shape) array containing
      the data for non-zero blocks of the mask.
  """
  M, N = mask.shape
  bm, bn = block_shape

  block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
  mask_types_finder = []
  mask_data = []
  mask_type_idxs = []

  next_mask_type_idx = 0
  prefetch_mask = jnp.zeros_like(block_mask)
  next_i = (M // bm) - 1
  next_j = (N // bn) - 1
  prefetch_i = jnp.zeros_like(block_mask)
  prefetch_j = jnp.zeros_like(block_mask)
  for i in range(M // bm, -1, -1):
    for j in range(N // bn, -1, -1):
      mask_block = mask[i * bm :(i + 1) * bm,
                        j * bn :(j + 1) * bn]
      is_nonzero = jnp.any(mask_block)
      if is_nonzero:
        try:
          type_index = mask_types_finder.index(str(mask_block))
        except ValueError:
          type_index = len(mask_types_finder)
          mask_types_finder.append(str(mask_block))
          mask_data.append(mask_block)
        next_mask_type_idx = type_index
        next_i = i
        next_j = j
      else:
        type_index = -1
      mask_type_idxs.append(type_index)
      block_mask = block_mask.at[i, j].set(is_nonzero)
      prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
      prefetch_i = prefetch_i.at[i, j].set(next_i)
      prefetch_j = prefetch_j.at[i, j].set(next_j)
  return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)

在内核的结构方面,我们使用与我们在之前的教程中介绍的标准矩阵乘法内核相同的网格模式,在 NMK 维度上进行 3 个循环。在内核本身中,我们首先检查 block_mask,以查看当前输出块的掩码是否全为零。如果掩码全为零,我们可以跳过计算并转到下一个块;否则,我们需要计算矩阵乘法,然后对结果进行掩码。

M = N = K = 16384
blk_M = blk_N = 512
blk_K = 1024

def sparse_mask_matmul(
    block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.
    x_ref, y_ref, mask_ref, o_ref,  # Kernel inputs.
    accum_scratch
    ):
  del prefetch_mask, prefetch_i, prefetch_j
  i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
  should_compute = block_mask_ref[i, j] != 0
  @pl.when(k == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

  # We only compute the output for blocks with non-zero masks.
  # Otherwise we skip the computation entirely.
  @pl.when(should_compute)
  def _():
    result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
    accum_scratch[...] += result
    @pl.when(k == pl.num_programs(2) - 1)
    def _():
      o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)

X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
mask = jnp.ones((M, N), dtype=jnp.int32)
mask = jnp.tril(mask)
block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))

def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_j
  # Zero-out the k index if the mask is zero, to avoid constantly fetching
  # new blocks in the inner loop for blocks we are skipping.
  k_fetch = (block_mask[i, j] != 0) * k
  return (prefetch_i[i, j], k_fetch)

def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_i
  k_fetch = (block_mask[i, j] != 0) * k
  return (k_fetch, prefetch_j[i, j])

def mask_map(i, j, k, block_mask, prefetch_mask, *_):
  del k, block_mask
  return (prefetch_mask[i, j], 0, 0)

def o_map(i, j, k, *_):
  del k
  return (i, j)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=4,
    grid=(M // blk_M, N // blk_N, K // blk_K),
    in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              pl.BlockSpec((1, blk_M, blk_N), mask_map)],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
  sparse_mask_matmul,
  grid_spec=grid_spec,
  out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
)
args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
result = kernel(*args)

ref = mask * (X @ Y)
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 1.0252e-05

现在让我们比较性能与朴素的稠密实现。在 TPU v5e 上,与理论上最佳情况下的 2 倍(来自使用下三角掩码且仅访问一半可能的输出)相比,我们使用稀疏内核实现了大约 1.8 倍的速度提升。

我们通常期望性能随着输入变大而更接近理论峰值,因为我们没有完全达到理论性能的几个主要原因是

  • 我们跳过的计算略少于一半,因为沿对角线的块是 0 和 1 的混合,对于混合块,我们需要计算整个块。随着输入变大,我们的混合块开销相对于整体计算变得更小。

  • 随着输入变大,流水线气泡也占总体运行时间的百分比较小。

n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
time = benchmark(ref_impl, n_trials)(mask, X, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 28.648 ms (avg over 100 trials)
Reference: 49.988 ms (avg over 100 trials)