jax.lax.psum#

jax.lax.psum(x, axis_name, *, axis_index_groups=None)[源]#

在 pmap 轴 axis_name 上,对 x 执行全规约求和。

如果 x 是一个 pytree,则结果等同于将此函数映射到树中的每个叶子。

布尔数据类型的输入在规约之前会转换为整数。

参数:
  • x – 具有名为 axis_name 的映射轴的数组。

  • axis_name – 可哈希的 Python 对象,用于命名 pmap 轴(详见 jax.pmap() 文档)。

  • axis_index_groups – 可选的列表的列表,包含轴索引(例如,对于大小为 4 的轴,[[0, 1], [2, 3]] 将对前两个和后两个副本执行 psum)。组必须精确地覆盖所有轴索引一次。

返回:

形状与 x 相同的数组,表示沿着 axis_name 轴的全规约求和结果。

示例

例如,如果有 4 个 XLA 设备可用

>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]

假设我们想在两个组之间执行 psum,一个组包含 device0device1,另一个组包含 device2device3

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[1 1 5 5]

一个使用 2D 形状 x 的示例。每行是来自一个设备的数据。

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]

在所有设备上的完全 psum

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[[24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]]

在两个组之间执行 psum

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[[ 4  6  8 10]
 [ 4  6  8 10]
 [20 22 24 26]
 [20 22 24 26]]