jax.experimental.multihost_utils.process_allgather#

jax.experimental.multihost_utils.process_allgather(in_tree, tiled=False)[source]#

从跨进程收集数据。

参数:
  • in_tree (Any) – 数组的 pytree - 每个数组_必须_在所有主机上具有相同的形状。

  • tiled (bool) – 是否堆叠或连接输出。默认为 False,即堆叠到索引 0 处的新位置轴。

返回:

numpy 数组的 Pytrees。
  • 如果输入是不可完全寻址的 jax.Array,则数据将被完全复制。

  • 如果输入是 numpy 数组或完全可寻址的 jax.Array,则输出形状取决于 tiled 参数。如果为 False,则输出将被堆叠,否则将被连接。

  • 如果输入是标量,则输出将被堆叠。

返回类型:

Any