jax.lax.associative_scan#
- jax.lax.associative_scan(fn, elems, reverse=False, axis=0)[源代码]#
使用结合律二元运算并行执行扫描。
有关结合律扫描的介绍,请参见 [BLE1990]。
- 参数:
fn (Callable) –
实现带有签名
r = fn(a, b)
的结合律二元运算的 Python 可调用对象。函数 fn 必须具有结合律,即,它必须满足方程式fn(a, fn(b, c)) == fn(fn(a, b), c)
。输入和结果是(可能是嵌套的 Python 树结构)与
elems
匹配的数组。 每个数组都有一个维度来代替axis
维度。 fn 应以元素方式应用于axis
维度(例如,通过对元素函数使用jax.vmap()
)。结果
r
具有与两个输入a
和b
相同的形状(和结构)。elems – 一个(可能是嵌套的 Python 树结构)数组,每个数组都有一个大小为
num_elems
的axis
维度。reverse (bool) – 一个布尔值,声明扫描是否应相对于
axis
维度反转。axis (int) – 一个整数,用于标识应发生扫描的轴。
- 返回:
一个(可能是嵌套的 Python 树结构)数组,其形状和结构与
elems
相同,其中axis
的第k
个元素是递归应用fn
以组合elems
沿axis
的前k
个元素的结果。 例如,给定elems = [a, b, c, ...]
,结果将是[a, fn(a, b), fn(fn(a, b), c), ...]
。如果
elems = [..., x, y, z]
且reverse
为真,则结果为[..., f(f(z, y), x), f(z, y), z]
。
示例 1:数字数组的部分和
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4)) Array([0, 1, 3, 6], dtype=int32)
示例 2:矩阵数组的部分积
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2)) >>> partial_prods = lax.associative_scan(jnp.matmul, mats) >>> partial_prods.shape (4, 2, 2)
示例 3:数字数组的反向部分和
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True) Array([6, 6, 5, 3], dtype=int32)
[BLE1990]Blelloch, Guy E. 1990. “前缀和及其应用”,技术报告 CMU-CS-90-190,卡内基梅隆大学计算机科学学院。