jax.lax.fori_loop#
- jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[source]#
通过归约到
jax.lax.while_loop()
,从lower
循环到upper
。简而言之,Haskell 风格的类型签名为
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
fori_loop
的语义由以下 Python 实现给出def fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): val = body_fun(i, val) return val
正如 Python 版本所暗示的,将
upper <= lower
会导致不执行任何迭代。不支持负值或自定义增量。与该 Python 版本不同的是,
fori_loop
是通过调用jax.lax.while_loop()
或调用jax.lax.scan()
来实现的。如果循环次数是静态的(即在追踪时已知,可能是因为lower
和upper
是 Python 整数字面量),则fori_loop
通过scan()
实现,并支持反向模式自动微分;否则,将使用while_loop
,且不支持反向模式自动微分。请参阅这些函数的文档字符串以获取更多信息。同样与 Python 对应物不同的是,循环传递值
val
在所有迭代中必须保持固定的形状和数据类型(而不仅仅是例如在 NumPy 秩/形状广播和数据类型提升规则下保持一致)。换句话说,上述类型签名中的类型a
表示一个具有固定形状和数据类型的数组(或者一个具有固定结构以及叶子处具有固定形状和数据类型数组的嵌套元组/列表/字典容器数据结构)。注意
fori_loop()
编译body_fun
,因此虽然它可以与jit()
结合使用,但这通常不是必需的。- 参数:
- 返回:
来自最后一次迭代的循环值,类型为
a
。