jax.experimental.shard_map.shard_map#
- jax.experimental.shard_map.shard_map(f, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[source]#
在数据分片上映射函数。
注意
shard_map
是一个实验性 API,仍可能发生变化。有关分片数据的介绍,请参阅并行编程入门。有关更深入地了解如何使用shard_map
,请参阅使用 shard_map 的 SPMD 多设备并行。- 参数:
f (Callable) – 要映射的可调用对象。
f
的每次应用,或f
的“实例”,都将映射参数的分片作为输入,并生成输出的分片。mesh (Mesh | AbstractMesh) – 一个
jax.sharding.Mesh
,表示用于对数据进行分片并在其上执行f
实例的设备数组。Mesh
的名称可用于f
中的集合通信操作。这通常由诸如jax.experimental.mesh_utils.create_device_mesh()
等实用程序函数创建。in_specs (Specs) – 一个 pytree,其叶子节点为
PartitionSpec
实例,其树结构是要映射到的 args 元组的树前缀。类似于NamedSharding
,每个PartitionSpec
表示相应的参数(或参数子树)应如何沿mesh
的命名轴进行分片。在每个PartitionSpec
中,在某个位置提及mesh
轴名称表示沿该位置轴分片相应的参数数组轴;不提及轴名称表示复制。如果参数或参数子树的对应规范为 None,则该参数不分片。out_specs (Specs) – 一个 pytree,其叶子节点为
PartitionSpec
实例,其树结构是f
输出的树前缀。每个PartitionSpec
表示应如何连接相应的输出分片。在每个PartitionSpec
中,在某个位置提及mesh
轴名称表示沿相应的位置轴连接该 mesh 轴的分片。不提及mesh
轴名称表示承诺输出值沿该 mesh 轴相等,并且应该只生成一个值而不是连接。check_rep (bool) – 如果为 True(默认值),则启用额外的有效性检查和自动微分优化。有效性检查涉及未在
out_specs
中提及的任何 mesh 轴名称是否与f
的输出如何复制一致。如果在f
中使用 Pallas 内核,则必须设置为 False。auto (frozenset[AxisName]) – (实验性)来自
mesh
的可选轴名称集合,我们不跨这些轴名称分片数据或映射函数,而是允许编译器控制分片。这些名称不能用于in_specs
、out_specs
或f
中的通信集合。
- 返回:
一个可调用对象,它根据
mesh
和in_specs
在分片数据上应用输入函数f
。
示例
有关示例,请参阅并行编程入门或使用 shard_map 的 SPMD 多设备并行。