jax.lax.psum_scatter#
- jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[source]#
类似于
psum(x, axis_name)
,但每个设备仅保留部分结果。例如,
psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)
计算的值与psum(x, axis_name)[axis_index(axis_name)]
相同,但效率更高。因此,psum
的结果会分散在映射的轴上。计算
psum(x, axis_name)
的一种高效算法是执行psum_scatter
,然后进行all_gather
,本质上是计算all_gather(psum_scatter(x, axis_name))
。因此,我们可以将psum_scatter
视为psum
的“前半部分”。- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 用于命名映射轴的可哈希 Python 对象(有关更多详细信息,请参阅
jax.pmap()
文档)。scatter_dimension – 一个位置轴,沿
axis_name
的 all-reduce 结果将分散到该轴中。axis_index_groups – 可选的整数列表的列表,包含轴索引。例如,对于大小为 4 的轴,
axis_index_groups=[[0, 1], [2, 3]]
将在前两个和后两个轴索引上运行 reduce-scatter。组必须完全覆盖所有轴索引一次,并且所有组的大小必须相同。tiled – 布尔值,表示是否使用保留秩的“平铺”行为。当
False
(默认值)时,scatter_dimension
中的维度大小必须与轴axis_name
的大小(如果给定axis_index_groups
,则为组大小)匹配。在沿scatter_dimension
分散 all-reduce 结果后,通过删除scatter_dimension
来挤压输出,因此结果的秩低于输入。当True
时,scatter_dimension
中的维度大小必须可以被轴axis_name
的大小(如果给定axis_index_groups
,则为组大小)整除,并且scatter_dimension
轴将被保留(因此结果的秩与输入相同)。
- 返回:
形状与
x
相似的数组,不同之处在于位置scatter_dimension
中的维度大小除以轴axis_name
的大小(当tiled=True
时),或者位置scatter_dimension
中的维度被消除(当tiled=False
时)。
例如,如果有 4 个 XLA 设备可用
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x) >>> print(y) [24 28 32 36]
如果使用平铺
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x) >>> print(y) [[24] [28] [32] [36]]
使用 axis_index_groups 的示例
>>> def f(x): ... return jax.lax.psum_scatter( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[ 8 10] [20 22] [12 14] [16 18]]