jax.lax.associative_scan#

jax.lax.associative_scan(fn, elems, reverse=False, axis=0)[source]#

使用结合性二元运算并行执行扫描。

有关结合性扫描的介绍,请参阅 [BLE1990]

参数:
  • fn (Callable) –

    一个 Python 可调用对象,实现结合性二元运算,签名 r = fn(a, b)。函数 fn 必须是结合性的,即它必须满足方程 fn(a, fn(b, c)) == fn(fn(a, b), c)

    输入和结果是(可能是嵌套的 Python 树结构)与 elems 匹配的数组。每个数组在 axis 维度位置都有一个维度。fn 应该在 axis 维度上逐元素应用(例如,通过使用 jax.vmap() 对逐元素函数进行向量化。)

    结果 r 具有与两个输入 ab 相同的形状(和结构)。

  • elems – 一个(可能是嵌套的 Python 树结构)数组,每个数组都有一个大小为 num_elemsaxis 维度。

  • reverse (bool) – 一个布尔值,指示扫描是否应该相对于 axis 维度反向。

  • axis (int) – 一个整数,标识扫描应发生的轴。

返回:

一个(可能是嵌套的 Python 树结构)数组,其形状和结构与 elems 相同,其中 axis 的第 k 个元素是通过递归应用 fn 来组合 elems 沿 axis 的前 k 个元素的结果。例如,给定 elems = [a, b, c, ...],结果将是 [a, fn(a, b), fn(fn(a, b), c), ...]

示例 1:数字数组的部分和

>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
Array([0, 1, 3, 6], dtype=int32)

示例 2:矩阵数组的部分积

>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)

示例 3:数字数组的反向部分和

>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
Array([6, 6, 5, 3], dtype=int32)
[BLE1990]

Blelloch, Guy E. 1990. “Prefix Sums and Their Applications.”, Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.