jax.experimental.multihost_utils.broadcast_one_to_all# jax.experimental.multihost_utils.broadcast_one_to_all(in_tree, is_source=None)[源码]# 将数据从源主机(默认为主机 0)广播到所有其他主机。 参数: in_tree (Any) – 数组的 Pytree - 跨主机的每个数组必须具有相同的形状。 is_source (bool | None) – 可选布尔值,表示调用者是否为源。只有“源主机”才会为广播提供数据。如果为 None,则使用主机 0。 返回: 一个与 in_tree 匹配的 Pytree,其中叶子现在都包含来自第一台主机的数据。 返回类型: 任意类型