jax.lax.map#
- jax.lax.map(f, xs, *, batch_size=None)[源码]#
在领先数组轴上映射函数。
类似于 Python 的内置 map 函数,但输入和输出是堆叠数组的形式。建议改用
vmap()
变换,除非你需要逐元素应用函数以减少内存使用或与其他控制流原语进行异构计算。当
xs
是数组类型时,map()
的语义由以下 Python 实现定义def map(f, xs): return np.stack([f(x) for x in xs])
与
scan()
类似,map()
是通过 JAX 原语实现的,因此与 Python 循环相比,它具有许多相同的优势:xs
可以是任意嵌套的 pytree 类型,并且映射计算只编译一次。如果提供了
batch_size
,计算将以该大小的批次执行,并使用vmap()
进行并行化。这可以作为map
的更优性能版本,也可以作为vmap
的内存高效版本使用。如果轴不能被批次大小整除,则剩余部分会在单独的vmap
中处理并连接到结果中。>>> x = jnp.ones((10, 3, 4)) >>> def f(x): ... print('inner shape:', x.shape) ... return x + 1 >>> y = lax.map(f, x, batch_size=3) inner shape: (3, 4) inner shape: (3, 4) >>> y.shape (10, 3, 4)
在上面的例子中,“inner shape”被打印了两次,一次是在追踪批处理计算时,另一次是在追踪剩余计算时。
- 参数:
f – 一个 Python 函数,用于在
xs
的第一个轴或多个轴上逐元素应用。xs – 沿领先轴进行映射的值。
batch_size (int | None) – (可选)指定每个步骤并行执行的批次大小的整数。
- 返回:
映射的值。