jax.lax.precv#

jax.lax.precv(token, out_shape, axis_name, perm)[source]#

根据置换 perm 执行集体接收 (recv)。

此函数是 Recv HLO 的一个类似物。

参数:
  • token – 一个编译器 token,由匹配的 psend 或 lax.create_token() 生成。 它用于强制集合体之间的控制依赖关系。

  • out_shape – 包含结果的 dtype 和形状的 ShapeDtypeStruct(s)。

  • axis_name – 可哈希的 Python 对象,用于命名 pmapped 轴(有关更多详细信息,请参阅 jax.pmap() 文档)。

  • perm – 整数对的列表,表示 (source_index, destination_index) 对,用于编码命名为 axis_name 的映射轴应如何混洗。 整数值被视为映射轴 axis_name 的索引。 任何两个对不应具有相同的源索引或相同的目标索引。 对于轴 axis_name 的每个索引,如果该索引不对应于 perm 中的目标索引,则结果中的相应值将填充适当类型的零。 此处的语义是平台特定的,对于 GPU,它们对应于 NCCL recv。

返回:

out_shape 具有相同形状的数组。