jax.lax.select_n#

jax.lax.select_n(which, *cases)[源代码]#

从多个情况中选择数组值。

概括了 XLA 的 Select 运算符。与 XLA 的版本不同,该运算符是可变的,并且可以使用整数 pred 从多种情况中进行选择。

参数:
  • which (ArrayLike) – 确定应返回哪种情况。 必须是一个包含布尔值或整数值的数组。 可以是标量或具有与 cases 匹配的形状。 对于每个数组元素,which 的值确定采用哪个 caseswhich 必须在范围 [0 .. len(cases)) 中; 对于该范围之外的值,行为是实现定义的。

  • *cases (ArrayLike) – 一个非空的数组情况列表。 所有都必须具有相等的 dtype 和相等的形状。

返回:

一个形状和 dtype 等于情况的数组,其值根据 which 选择。

返回类型:

Array