jax.experimental.shard_map.shard_map#
- jax.experimental.shard_map.shard_map(f, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[源代码]#
将函数映射到数据的分片上。
注意
shard_map是一个实验性的 API,并且仍然可能会更改。有关分片数据的介绍,请参阅 并行计算简介。有关使用shard_map的更深入了解,请参阅 使用 shard_map 的 SPMD 多设备并行。- 参数:
f (Callable) – 要映射的可调用对象。每次应用
f,或f的“实例”,都将映射覆盖参数的分片作为输入,并生成输出的分片。mesh (Mesh | AbstractMesh) – 一个
jax.sharding.Mesh,表示用于对数据进行分片以及在其上执行f实例的设备阵列。Mesh的名称可以在f中的集体通信操作中使用。这通常由一个实用函数创建,例如jax.experimental.mesh_utils.create_device_mesh()。in_specs (Any) – 一个 pytree,其叶子节点为
PartitionSpec实例,其树结构是要映射覆盖的 args 元组的树前缀。与NamedSharding类似,每个PartitionSpec表示相应的参数(或参数子树)应如何沿mesh的命名轴进行分片。在每个PartitionSpec中,在某个位置提及mesh轴名称表示沿该位置轴对相应的参数数组轴进行分片;不提及轴名称表示复制。如果参数或参数子树的对应规范为 None,则该参数不分片。out_specs (Any) – 一个 pytree,其叶子节点为
PartitionSpec实例,其树结构是f输出的树前缀。每个PartitionSpec表示应如何连接相应的输出分片。在每个PartitionSpec中,在某个位置提及mesh轴名称表示沿相应的位置轴连接该 mesh 轴的分片。不提及mesh轴名称表示承诺输出值在该 mesh 轴上相等,并且应该生成单个值而不是连接。check_rep (bool) – 如果为 True(默认),则启用额外的有效性检查和自动微分优化。有效性检查涉及未在
out_specs中提及的任何 mesh 轴名称是否与f的输出如何复制一致。如果在f中使用 Pallas 内核,则必须设置为 False。auto (frozenset[Hashable]) – (实验性) 来自
mesh的轴名称的可选集合,我们不会在这些轴上对数据进行分片或映射函数,而是允许编译器控制分片。这些名称不能在in_specs、out_specs或f中的通信集合中使用。
- 返回:
一个可调用对象,用于在根据
mesh和in_specs分片的数据上应用输入函数f。
示例
有关示例,请参阅 并行计算简介 或 使用 shard_map 的 SPMD 多设备并行。