jax.lax.switch#
- jax.lax.switch(index, branches, *operands, operand=<object object>)[源代码]#
根据
index应用branches中的一个且仅一个分支。如果
index超出范围,则会将其夹紧到范围内。具有以下 Python 的语义:
def switch(index, branches, *operands): index = clamp(0, index, len(branches) - 1) return branches[index](*operands)
内部来说,这包装了 XLA 的 Conditional 操作符。但是,当使用
vmap()转换为操作一批谓词时,cond会被转换为select()。- 参数:
index – 整数标量类型,指示要应用的分支函数。
branches (Sequence[Callable]) – 函数序列(A -> B),根据
index应用。所有分支都必须返回相同的输出结构。operands – 输入给所应用分支的操作数(A)。
- 返回:
根据
index选择的分支branch(*operands)的值(B)。