jax.shard_map#
- jax.shard_map(f=None, /, *, out_specs, axis_names={}, in_specs=None, mesh=None, check_vma=True)[源码]#
使用设备网格将函数映射到数据分片上。
请参阅以下文档:https://jax.net.cn/en/latest/notebooks/shard_map.html。
- 参数:
f – 可映射的调用对象。
f
的每次应用,或称f
的“实例”,将映射参数的分片作为输入,并生成输出的分片。mesh (Mesh | AbstractMesh | None) – (可选,默认为 None) 一个
jax.sharding.Mesh
,表示用于数据分片和执行f
实例的设备数组。Mesh
的名称可在f
中的集体通信操作中使用。如果mesh为None,它将从可通过jax.sharding.use_mesh上下文管理器设置的上下文中推断。in_specs (Specs | None) – (可选,默认为 None) 一个以
jax.sharding.PartitionSpec
实例作为叶节点的Pytree,其树结构是待映射参数元组的树前缀。类似于jax.sharding.NamedSharding
,每个PartitionSpec
表示相应的参数(或参数子树)应如何沿mesh
的命名轴进行分片。在每个PartitionSpec
中,在某个位置提及mesh
轴名称表示沿该位置轴对相应参数数组轴进行分片;不提及轴名称表示复制。如果为None
,所有mesh轴必须是Explicit类型,在这种情况下,in_specs将从参数类型中推断出来。out_specs (Specs) – 一个以
PartitionSpec
实例作为叶节点的Pytree,其树结构是f
输出的树前缀。每个PartitionSpec
表示相应的输出分片应如何拼接。在每个PartitionSpec
中,在某个位置提及mesh
轴名称表示沿相应位置轴拼接该mesh轴的分片;不提及mesh
轴名称表示承诺输出值在该mesh轴上相等,并且不进行拼接只应生成单个值。axis_names (Set[AxisName]) – (可选,默认为空集)
mesh
中轴名称的集合,函数f
在这些轴上手动操作。如果为空,f
将在所有mesh轴上手动操作。check_vma (bool) – (可选) 布尔值(默认为 True),表示是否启用额外的有效性检查和自动微分优化。有效性检查关注
out_specs
中未提及的任何mesh轴名称是否与f
的输出复制方式一致。
- 返回:
一个可调用对象,代表
f
的映射版本,它接受与f
对应的位置参数,并生成与f
对应的输出。