jax.shard_map#

jax.shard_map(f=None, /, *, out_specs, axis_names={}, in_specs=None, mesh=None, check_vma=True)[源码]#

使用设备网格将函数映射到数据分片上。

请参阅以下文档:https://jax.net.cn/en/latest/notebooks/shard_map.html

参数:
  • f – 可映射的调用对象。f的每次应用,或称f的“实例”,将映射参数的分片作为输入,并生成输出的分片。

  • mesh (Mesh | AbstractMesh | None) – (可选,默认为 None) 一个jax.sharding.Mesh,表示用于数据分片和执行f实例的设备数组。Mesh的名称可在f中的集体通信操作中使用。如果mesh为None,它将从可通过jax.sharding.use_mesh上下文管理器设置的上下文中推断。

  • in_specs (Specs | None) – (可选,默认为 None) 一个以jax.sharding.PartitionSpec实例作为叶节点的Pytree,其树结构是待映射参数元组的树前缀。类似于jax.sharding.NamedSharding,每个PartitionSpec表示相应的参数(或参数子树)应如何沿mesh的命名轴进行分片。在每个PartitionSpec中,在某个位置提及mesh轴名称表示沿该位置轴对相应参数数组轴进行分片;不提及轴名称表示复制。如果为None,所有mesh轴必须是Explicit类型,在这种情况下,in_specs将从参数类型中推断出来。

  • out_specs (Specs) – 一个以PartitionSpec实例作为叶节点的Pytree,其树结构是f输出的树前缀。每个PartitionSpec表示相应的输出分片应如何拼接。在每个PartitionSpec中,在某个位置提及mesh轴名称表示沿相应位置轴拼接该mesh轴的分片;不提及mesh轴名称表示承诺输出值在该mesh轴上相等,并且不进行拼接只应生成单个值。

  • axis_names (Set[AxisName]) – (可选,默认为空集) mesh中轴名称的集合,函数f在这些轴上手动操作。如果为空,f将在所有mesh轴上手动操作。

  • check_vma (bool) – (可选) 布尔值(默认为 True),表示是否启用额外的有效性检查和自动微分优化。有效性检查关注out_specs中未提及的任何mesh轴名称是否与f的输出复制方式一致。

返回:

一个可调用对象,代表f的映射版本,它接受与f对应的位置参数,并生成与f对应的输出。