jax.smap#

jax.smap(f=None, /, *, in_axes=jax.sharding.Infer, out_axes, axis_name)[source]#

单轴 shard_map,用于一次映射函数 f 的一个轴。

参数:
  • f – 待映射的可调用对象。每次应用 f,即 f 的“实例”,都将映射参数的分片作为输入,并生成输出的分片。

  • in_axes – (可选)一个整数、None 或值序列,指定要映射的输入数组轴。如果未指定,smap 将仅在 Explicit 模式下尝试从参数推断轴。一个整数或 None 表示所有参数要映射的数组轴(None 表示不映射任何轴),元组则表示每个相应的位置参数要映射的轴。对于每个数组,轴整数必须在 [-ndim, ndim) 范围内,其中 ndim 是相应输入数组的维度(轴)数量。

  • out_axes – 一个整数、None 或(嵌套的)标准 Python 容器(元组/列表/字典),指示映射轴应出现在输出的何处。

  • axis_name (AxisName) – mesh 轴的名称,函数 f 在其上手动操作。

返回:

一个可调用对象,表示 f 的映射版本,它接受与 f 对应的位置参数,并生成与 f 对应的输出。