jax.lax.scan#
- jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)[source]#
在沿状态传递时,扫描函数在数组前导轴上。
简而言之,Haskell 风格的类型签名为
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
其中,对于任何数组类型说明符
t
,[t]
表示具有额外前导轴的类型;如果t
是一个带数组叶子的 pytree(容器)类型,那么[t]
表示具有相同 pytree 结构且对应叶子各带一个额外前导轴的类型。当
xs
的类型(上面表示为 a)为数组类型或 None,且ys
的类型(上面表示为 b)为数组类型时,scan()
的语义大致由以下 Python 实现给出def scan(f, init, xs, length=None): if xs is None: xs = [None] * length carry = init ys = [] for x in xs: carry, y = f(carry, x) ys.append(y) return carry, np.stack(ys)
与该 Python 版本不同,
xs
和ys
都可以是任意 pytree 值,因此可以一次性扫描多个数组并生成多个输出数组。None
实际上是这种情况的一个特例,因为它表示一个空的 pytree。此外,与该 Python 版本不同,
scan()
是一个 JAX 原语,并被降低为单个 WhileOp。这对于减少 JIT 编译函数的编译时间非常有用,因为jit()
函数中的原生 Python 循环构造会被展开,导致大型 XLA 计算。最后,循环携带值
carry
必须在所有迭代中保持固定的形状和数据类型(而不仅仅是符合 NumPy 的秩/形状广播和数据类型提升规则,例如)。换句话说,上面类型签名中的类型c
表示一个具有固定形状和数据类型的数组(或一个具有固定结构且在叶子处具有固定形状和数据类型的数组的嵌套元组/列表/字典容器数据结构)。注意
scan()
会编译f
,所以虽然它可以与jit()
结合使用,但通常没有必要。注意
scan()
旨在用于静态迭代次数的循环。对于动态迭代次数的循环,请使用fori_loop()
或while_loop()
。- 参数:
f (Callable[[Carry, X], tuple[Carry, Y]]) – 一个待扫描的 Python 函数,类型为
c -> a -> (c, b)
,意味着f
接受两个参数,第一个是循环携带值,第二个是xs
沿其前导轴的一个切片,并且f
返回一个二元组,其中第一个元素代表循环携带的新值,第二个代表输出的一个切片。init (Carry) – 类型为
c
的初始循环携带值,它可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典)形式,表示初始循环携带值。此值必须与f
返回的二元组的第一个元素的结构相同。xs (X | None) – 类型为
[a]
的值,用于沿前导轴进行扫描,其中[a]
可以是数组或任何具有一致前导轴大小的 pytree(嵌套的 Python 元组/列表/字典)形式。length (int | None) – 可选整数,指定循环迭代次数,其必须与
xs
中数组前导轴的大小一致(但可用于在不需要输入xs
的情况下执行扫描)。reverse (bool) – 可选布尔值,指定是向前(默认)还是向后运行扫描迭代,这等同于反转
xs
和ys
中数组的前导轴。unroll (int | bool) – 可选正整数或布尔值,在扫描原语的底层操作中,指定在单个循环迭代中展开多少个扫描迭代。如果提供整数,它将决定在循环的单个卷绕迭代中运行多少个展开的循环迭代。如果提供布尔值,它将决定循环是完全展开(即 unroll=True)还是保持完全卷绕(即 unroll=False)。
_split_transpose (bool) – 实验性可选布尔值,指定是否将转置进一步拆分为扫描(计算激活梯度)和映射(计算与数组参数对应的梯度)。启用此功能可能会增加内存需求,因此这是一个实验性功能,可能会发展或甚至被回滚。
- 返回:
一个类型为
(c, [b])
的二元组,其中第一个元素表示最终循环携带值,第二个元素表示在输入的前导轴上扫描f
的第二个输出的堆叠结果。- 返回类型:
tuple[Carry, Y]