jax.lax.psum_scatter#
- jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[source]#
类似于
psum(x, axis_name)
,但每个设备仅保留部分结果。例如,
psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)
计算的值与psum(x, axis_name)[axis_index(axis_name)]
相同,但效率更高。因此,psum
结果沿映射轴分散。计算
psum(x, axis_name)
的一种高效算法是执行psum_scatter
,然后执行all_gather
,本质上是评估all_gather(psum_scatter(x, axis_name))
。因此,我们可以将psum_scatter
视为psum
的“前半部分”。- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 用于命名映射轴的可哈希 Python 对象(有关更多详细信息,请参阅
jax.pmap()
文档)。scatter_dimension – 一个位置轴,沿
axis_name
的 all-reduce 结果将分散到其中。axis_index_groups – 可选的整数列表的列表,包含轴索引。例如,对于大小为 4 的轴,
axis_index_groups=[[0, 1], [2, 3]]
将在前两个和后两个轴索引上运行 reduce-scatter。组必须精确覆盖所有轴索引一次,并且所有组的大小必须相同。tiled – 布尔值,表示是否使用保留秩的“平铺”行为。当
False
(默认值)时,scatter_dimension
中的维度大小必须与轴axis_name
的大小(或者如果给定axis_index_groups
,则为组大小)匹配。在沿scatter_dimension
分散 all-reduce 结果后,输出通过移除scatter_dimension
被压缩,因此结果的秩低于输入。当True
时,scatter_dimension
中的维度大小必须可被轴axis_name
的大小(或者如果给定axis_index_groups
,则为组大小)整除,并且scatter_dimension
轴被保留(因此结果与输入具有相同的秩)。
- 返回:
与
x
具有相似形状的数组,除了位置scatter_dimension
中的维度大小除以轴axis_name
的大小(当tiled=True
时),或者位置scatter_dimension
中的维度被消除(当tiled=False
时)。
例如,在有 4 个 XLA 设备可用的情况下
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x) >>> print(y) [24 28 32 36]
如果使用 tiled
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x) >>> print(y) [[24] [28] [32] [36]]
使用 axis_index_groups 的示例
>>> def f(x): ... return jax.lax.psum_scatter( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[ 8 10] [20 22] [12 14] [16 18]]