jax.lax.while_loop#
- jax.lax.while_loop(cond_fun, body_fun, init_val)[来源]#
当
cond_fun
为 True 时,在循环中重复调用body_fun
。简而言之,Haskell 风格的类型签名为
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
while_loop
的语义由以下 Python 实现给出def while_loop(cond_fun, body_fun, init_val): val = init_val while cond_fun(val): val = body_fun(val) return val
与该 Python 版本不同,
while_loop
是一个 JAX 原语,并被转换为单个 WhileOp。这有助于减少 jit 编译函数的编译时间,因为在@jit
函数中的原生 Python 循环结构会被展开,导致大型的 XLA 计算。与 Python 对应物不同的是,循环传递值
val
在所有迭代中必须保持固定的形状和数据类型(而不仅仅是与 NumPy 的秩/形状广播和数据类型提升规则一致,例如)。换句话说,上述类型签名中的类型a
表示一个具有固定形状和数据类型的数组(或者是一个具有固定结构、且叶子节点处数组具有固定形状和数据类型的嵌套元组/列表/字典容器数据结构)。与使用原生 Python 循环结构相比,另一个不同之处在于
while_loop
不支持反向模式可微分,因为 XLA 计算需要内存需求的静态边界。注意
while_loop()
会编译cond_fun
和body_fun
,因此虽然它可以与jit()
结合使用,但通常没有必要。- 参数:
cond_fun (Callable[[T], BooleanNumeric]) — 类型为
a -> Bool
的函数。body_fun (Callable[[T], T]) — 类型为
a -> a
的函数。init_val (T) — 类型为
a
的值,该类型可以是标量、数组或任何 Pytree(嵌套的 Python 元组/列表/字典),表示初始的循环传递值。
- 返回:
来自 body_fun 最终迭代的输出,类型为
a
。- 返回类型:
T