jax.lax.cond#
- jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)[来源]#
有条件地应用
true_fun
或false_fun
。包装 XLA 的 Conditional 运算符。
提供的参数类型正确,
cond()
的语义等同于此 Python 实现,其中pred
必须是标量类型。def cond(pred, true_fun, false_fun, *operands): if pred: return true_fun(*operands) else: return false_fun(*operands)
与
jax.lax.select()
相比,使用cond
表示只有两个分支中的一个会被执行(在编译器重写和优化之前)。然而,当使用vmap()
转换以处理一组谓词时,cond
会被转换为select()
。在所有情况下,两个分支都会被追踪(请参阅 关键概念:追踪 了解 JAX 的追踪模型)。- 参数:
pred – 布尔标量类型,指示应用哪个分支函数。
true_fun (Callable) – 如果
pred
为 True,则应用此函数 (A -> B)。false_fun (Callable) – 如果
pred
为 False,则应用此函数 (A -> B)。operands – 要由
pred
的值决定输入到其中一个分支的操作数 (A)。类型可以是标量、数组或它们的任何 pytree(嵌套的 Python 元组/列表/字典)。
- 返回:
值 (B),根据
pred
的值,是true_fun(*operands)
的值或false_fun(*operands)
的值。类型可以是标量、数组或它们的任何 pytree(嵌套的 Python 元组/列表/字典)。