jax.lax.cond#

jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)[源]#

有条件地应用 true_funfalse_fun

封装了 XLA 的 Conditional 运算符。

如果提供的参数类型正确,cond() 具有与此 Python 实现等效的语义,其中 pred 必须是标量类型

def cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)

jax.lax.select() 相比,使用 cond 表示只有两个分支中的一个会被执行(取决于编译器的重写和优化)。然而,当使用 vmap() 转换为对一批谓词进行操作时,cond 会被转换为 select()。在所有情况下,两个分支都会被追踪(参见 关键概念:追踪 以了解 JAX 追踪模型的讨论)。

参数:
  • pred – 布尔标量类型,指示要应用哪个分支函数。

  • true_fun (可调用对象) – 函数 (A -> B),当 pred 为 True 时应用。

  • false_fun (可调用对象) – 函数 (A -> B),当 pred 为 False 时应用。

  • operands – 根据 pred 的值输入到任一分支的操作数 (A)。类型可以是标量、数组,或它们的任何PyTree(嵌套的Python元组/列表/字典)。

返回:

值为 (B),取决于 pred 的值,为 true_fun(*operands)false_fun(*operands)。类型可以是标量、数组,或它们的任何PyTree(嵌套的Python元组/列表/字典)。