jax.lax.psend#
- jax.lax.psend(x, axis_name, perm)[source]#
根据排列
perm
执行集体发送。如果
x
是一个 pytree,则结果等同于将此函数映射到树中的每个叶子。此函数是 Send HLO 的模拟。
- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 可哈希的 Python 对象,用于命名 pmapped 轴(有关更多详细信息,请参阅
jax.pmap()
文档)。perm – 整数对列表,表示
(source_index, destination_index)
对,用于编码名为axis_name
的映射轴应如何洗牌。整数值被视为映射轴axis_name
的索引。任何两个对都不应具有相同的源索引或相同的目标索引。对于轴axis_name
的每个索引,该索引不对应于perm
中的目标索引,结果中的相应值将填充适当类型的零。这里的语义是特定于平台的,对于 GPU,它们对应于 NCCL 发送。
- 返回:
编译器令牌,可由 precv 和 lax.optimzation_barrier 使用,以强制执行集体操作的排序。