jax.lax.psum_scatter#

jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[source]#

类似于 psum(x, axis_name),但每个设备仅保留部分结果。

例如,psum_scatter(x, axis_name, scatter_dimension=0, tiled=False) 计算的值与 psum(x, axis_name)[axis_index(axis_name)] 相同,但效率更高。因此,psum 的结果会分散在映射的轴上。

计算 psum(x, axis_name) 的一种高效算法是执行 psum_scatter,然后进行 all_gather,本质上是计算 all_gather(psum_scatter(x, axis_name))。因此,我们可以将 psum_scatter 视为 psum 的“前半部分”。

参数:
  • x – 具有名为 axis_name 的映射轴的数组。

  • axis_name – 用于命名映射轴的可哈希 Python 对象(有关更多详细信息,请参阅 jax.pmap() 文档)。

  • scatter_dimension – 一个位置轴,沿 axis_name 的 all-reduce 结果将分散到该轴中。

  • axis_index_groups – 可选的整数列表的列表,包含轴索引。例如,对于大小为 4 的轴,axis_index_groups=[[0, 1], [2, 3]] 将在前两个和后两个轴索引上运行 reduce-scatter。组必须完全覆盖所有轴索引一次,并且所有组的大小必须相同。

  • tiled – 布尔值,表示是否使用保留秩的“平铺”行为。当 False (默认值)时,scatter_dimension 中的维度大小必须与轴 axis_name 的大小(如果给定 axis_index_groups,则为组大小)匹配。在沿 scatter_dimension 分散 all-reduce 结果后,通过删除 scatter_dimension 来挤压输出,因此结果的秩低于输入。当 True 时,scatter_dimension 中的维度大小必须可以被轴 axis_name 的大小(如果给定 axis_index_groups,则为组大小)整除,并且 scatter_dimension 轴将被保留(因此结果的秩与输入相同)。

返回:

形状与 x 相似的数组,不同之处在于位置 scatter_dimension 中的维度大小除以轴 axis_name 的大小(当 tiled=True 时),或者位置 scatter_dimension 中的维度被消除(当 tiled=False 时)。

例如,如果有 4 个 XLA 设备可用

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x)
>>> print(y)
[24 28 32 36]

如果使用平铺

>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x)
>>> print(y)
[[24]
 [28]
 [32]
 [36]]

使用 axis_index_groups 的示例

>>> def f(x):
...   return jax.lax.psum_scatter(
...       x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True)
>>> y = jax.pmap(f, axis_name='i')(x)
>>> print(y)
[[ 8 10]
 [20 22]
 [12 14]
 [16 18]]