jax.lax.fori_loop#

jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[source]#

通过归约到 jax.lax.while_loop(),从 lower 循环到 upper

简而言之,Haskell 风格的类型签名

fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

fori_loop 的语义由以下 Python 实现给出

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

正如 Python 版本所暗示的,将 upper <= lower 会导致不执行任何迭代。不支持负值或自定义增量。

与该 Python 版本不同的是,fori_loop 是通过调用 jax.lax.while_loop() 或调用 jax.lax.scan() 来实现的。如果循环次数是静态的(即在追踪时已知,可能是因为 lowerupper 是 Python 整数字面量),则 fori_loop 通过 scan() 实现,并支持反向模式自动微分;否则,将使用 while_loop,且不支持反向模式自动微分。请参阅这些函数的文档字符串以获取更多信息。

同样与 Python 对应物不同的是,循环传递值 val 在所有迭代中必须保持固定的形状和数据类型(而不仅仅是例如在 NumPy 秩/形状广播和数据类型提升规则下保持一致)。换句话说,上述类型签名中的类型 a 表示一个具有固定形状和数据类型的数组(或者一个具有固定结构以及叶子处具有固定形状和数据类型数组的嵌套元组/列表/字典容器数据结构)。

注意

fori_loop() 编译 body_fun,因此虽然它可以与 jit() 结合使用,但这通常不是必需的。

参数:
  • lower – 表示循环索引下限(包含)的整数

  • upper – 表示循环索引上限(不包含)的整数

  • body_fun – 类型为 (int, a) -> a 的函数。

  • init_val – 类型为 a 的初始循环传递值。

  • unroll (int | bool | None) – 一个可选的整数或布尔值,用于确定循环展开的程度。如果提供一个整数,它决定了循环的展开步长(即,在一个循环迭代中执行多少次原始迭代)。如果提供布尔值,它将确定循环是否完全展开(即 unroll=True)或完全不展开(即 unroll=False)。此参数仅在循环边界静态已知时适用。

返回:

来自最后一次迭代的循环值,类型为 a