jax.lax.ppermute#
- jax.lax.ppermute(x, axis_name, perm)[source]#
根据排列
perm
执行集体置换。如果
x
是一个 pytree,那么结果等同于将此函数映射到树中的每个叶节点。此函数是 CollectivePermute HLO 的一个类似物。
- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(有关更多详细信息,请参阅
jax.pmap()
文档)。perm – 表示
(source_index, destination_index)
对的整数对列表,用于编码应如何混洗名为axis_name
的映射轴。整数值被视为映射轴axis_name
的索引。任何两个对都不应具有相同的源索引或相同的目标索引。对于轴axis_name
的每个索引,如果该索引不对应于perm
中的目标索引,则结果中的对应值将填充适当类型的零。
- 返回:
与
x
具有相同形状的数组,其切片沿着轴axis_name
从x
中根据排列perm
收集而来。