jax.lax.psum_scatter#
- jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[源代码]#
类似于
psum(x, axis_name)
,但每个设备仅保留部分结果。例如,
psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)
计算的值与psum(x, axis_name)[axis_index(axis_name)]
相同,但效率更高。因此,psum
结果沿映射轴分散。一种计算
psum(x, axis_name)
的有效算法是执行psum_scatter
,然后执行all_gather
,实质上是评估all_gather(psum_scatter(x, axis_name))
。因此,我们可以将psum_scatter
视为psum
的“前半部分”。- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 用于命名映射轴的可哈希 Python 对象(有关更多详细信息,请参见
jax.pmap()
文档)。scatter_dimension – 一个位置轴,沿
axis_name
的 all-reduce 结果将分散到该轴中。axis_index_groups – 可选的整数列表的列表,其中包含轴索引。例如,对于大小为 4 的轴,
axis_index_groups=[[0, 1], [2, 3]]
将在前两个和后两个轴索引上运行 reduce-scatter。组必须精确地覆盖所有轴索引一次,并且所有组的大小必须相同。tiled – 布尔值,表示是否使用保留秩的“tiled”行为。当
False
(默认值)时,scatter_dimension
中的维度大小必须与轴axis_name
的大小匹配(如果给定axis_index_groups
,则与组大小匹配)。在沿scatter_dimension
分散 all-reduce 结果后,通过删除scatter_dimension
来压缩输出,因此结果的秩低于输入。当True
时,scatter_dimension
中的维度大小必须可被轴axis_name
的大小整除(如果给定axis_index_groups
,则可被组大小整除),并且保留scatter_dimension
轴(因此结果与输入具有相同的秩)。
- 返回:
形状与
x
类似的数组,除了位置scatter_dimension
中的维度大小除以轴axis_name
的大小(当tiled=True
时),或者消除位置scatter_dimension
中的维度(当tiled=False
时)。
例如,如果有 4 个 XLA 设备可用
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x) >>> print(y) [24 28 32 36]
如果使用 tiled
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x) >>> print(y) [[24] [28] [32] [36]]
使用 axis_index_groups 的一个例子
>>> def f(x): ... return jax.lax.psum_scatter( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[ 8 10] [20 22] [12 14] [16 18]]