jax.lax.while_loop#
- jax.lax.while_loop(cond_fun, body_fun, init_val)[源码]#
在
cond_fun为 True 的情况下,重复调用body_fun来执行循环。简而言之,Haskell 风格的类型签名为
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
while_loop的语义由以下 Python 实现给出:def while_loop(cond_fun, body_fun, init_val): val = init_val while cond_fun(val): val = body_fun(val) return val
与该 Python 版本不同,
while_loop是一个 JAX 原始操作,会被降低为单个 WhileOp。这有助于减少 jit 编译函数的编译时间,因为@jit函数中的原生 Python 循环结构会被展开,导致生成大型 XLA 计算。同样与 Python 模拟不同的是,循环携带的值
val在所有迭代中必须保持固定的形状和 dtype(而不仅仅是与 NumPy 的 rank/shape 广播和 dtype 提升规则一致)。换句话说,上面类型签名中的类型a代表一个具有固定形状和 dtype 的数组(或者一个具有固定结构并在叶节点具有固定形状和 dtype 的数组的嵌套元组/列表/字典容器数据结构)。与使用 Python 原生循环结构另一个不同之处在于,
while_loop不支持反向模式微分,因为 XLA 计算需要静态的内存需求边界。注意
while_loop()会编译cond_fun和body_fun,因此虽然它可以与jit()结合使用,但通常是不必要的。- 参数:
cond_fun (Callable[[T], BooleanNumeric]) – 类型为
a -> Bool的函数。body_fun (Callable[[T], T]) – 类型为
a -> a的函数。init_val (T) – 类型为
a的值,该类型可以是标量、数组,或者它们的任何 pytree(嵌套的 Python 元组/列表/字典),代表初始的循环携带值。
- 返回:
body_fun 最后一次迭代的输出,类型为
a。- 返回类型:
T