jax.lax.precv#
- jax.lax.precv(token, out_shape, axis_name, perm)[source]#
根据置换
perm
执行集体接收 (recv)。此函数是 Recv HLO 的一个类似物。
- 参数:
token – 一个编译器 token,由匹配的 psend 或 lax.create_token() 生成。 它用于强制集合体之间的控制依赖关系。
out_shape – 包含结果的 dtype 和形状的 ShapeDtypeStruct(s)。
axis_name – 可哈希的 Python 对象,用于命名 pmapped 轴(有关更多详细信息,请参阅
jax.pmap()
文档)。perm – 整数对的列表,表示
(source_index, destination_index)
对,用于编码命名为axis_name
的映射轴应如何混洗。 整数值被视为映射轴axis_name
的索引。 任何两个对不应具有相同的源索引或相同的目标索引。 对于轴axis_name
的每个索引,如果该索引不对应于perm
中的目标索引,则结果中的相应值将填充适当类型的零。 此处的语义是平台特定的,对于 GPU,它们对应于 NCCL recv。
- 返回:
与
out_shape
具有相同形状的数组。