jax.smap#
- jax.smap(f=None, /, *, in_axes=jax.sharding.Infer, out_axes, axis_name)[source]#
单轴 shard_map,用于一次映射函数 f 的一个轴。
- 参数:
f – 待映射的可调用对象。每次应用
f
,即f
的“实例”,都将映射参数的分片作为输入,并生成输出的分片。in_axes – (可选)一个整数、None 或值序列,指定要映射的输入数组轴。如果未指定,smap 将仅在 Explicit 模式下尝试从参数推断轴。一个整数或
None
表示所有参数要映射的数组轴(None
表示不映射任何轴),元组则表示每个相应的位置参数要映射的轴。对于每个数组,轴整数必须在[-ndim, ndim)
范围内,其中ndim
是相应输入数组的维度(轴)数量。out_axes – 一个整数、None 或(嵌套的)标准 Python 容器(元组/列表/字典),指示映射轴应出现在输出的何处。
axis_name (AxisName) –
mesh
轴的名称,函数f
在其上手动操作。
- 返回:
一个可调用对象,表示
f
的映射版本,它接受与f
对应的位置参数,并生成与f
对应的输出。