jax.lax.map#

jax.lax.map(f, xs, *, batch_size=None)[源代码]#

在领先的数组轴上映射函数。

类似于 Python 的内置 map,除了输入和输出是以堆叠数组的形式。 考虑使用 vmap() 转换,除非您需要逐元素应用函数以减少内存使用量或与其他控制流原语进行异构计算。

xs 是数组类型时,map() 的语义由这个 Python 实现给出

def map(f, xs):
  return np.stack([f(x) for x in xs])

scan() 一样,map() 是用 JAX 原语实现的,因此 Python 循环的许多相同优势也适用:xs 可以是任意嵌套的 pytree 类型,并且映射的计算只编译一次。

如果提供了 batch_size,则计算将以该大小的批次执行,并使用 vmap() 进行并行化。 这可以用作 map 的更高效版本或 vmap 的内存高效版本。 如果轴不能被批次大小整除,则剩余部分将在单独的 vmap 中处理并连接到结果中。

>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)

在上面的示例中,“inner shape” 打印了两次,一次是在跟踪批处理计算时,另一次是在跟踪剩余计算时。

参数:
  • f – 一个 Python 函数,用于在 xs 的第一个轴或多个轴上逐元素应用。

  • xs – 要在其前导轴上映射的值。

  • batch_size (int | None) – (可选) 整数,指定每个步骤并行执行的批次大小。

返回:

映射的值。