jax.numpy.choose#
- jax.numpy.choose(a, choices, out=None, mode='raise')[源代码]#
通过堆叠选择数组的切片来构造数组。
JAX 对
numpy.choose()的实现。此函数的作用可能令人困惑,但在最简单的情况下,当
a是一个一维数组,choices是一个二维数组,并且a的所有条目都在边界内(即0 <= a_i < len(choices)),那么该函数等同于以下操作:def choose(a, choices): return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])
在更一般的情况下,
a可以具有任意数量的维度,并且choices可以是任意序列的广播兼容数组。在这种情况下,同样对于边界内的索引,其逻辑等同于:def choose(a, choices): a, *choices = jnp.broadcast_arrays(a, *choices) choices = jnp.array(choices) return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)])
唯一的额外复杂性来自于
mode参数,该参数控制a中越界索引的行为,如下所述。- 参数:
- 返回:
一个数组,其中包含在
a指定的索引处从choices堆叠的切片。结果的形状是broadcast_shapes(a.shape, *(c.shape for c in choices))。- 返回类型:
另请参阅
jax.lax.switch():根据索引在 N 个函数之间进行选择。
示例
这是 1D 索引数组和 2D 选择数组的最简单情况,在这种情况下,它会从每列中选择索引值
>>> choices = jnp.array([[ 1, 2, 3, 4], ... [ 5, 6, 7, 8], ... [ 9, 10, 11, 12]]) >>> a = jnp.array([2, 0, 1, 0]) >>> jnp.choose(a, choices) Array([9, 2, 7, 4], dtype=int32)
mode参数指定如何处理越界索引;选项包括wrap或clip>>> a2 = jnp.array([2, 0, 1, 4]) # last index out-of-bound >>> jnp.choose(a2, choices, mode='clip') Array([ 9, 2, 7, 12], dtype=int32) >>> jnp.choose(a2, choices, mode='wrap') Array([9, 2, 7, 8], dtype=int32)
在更一般的情况下,
choices可以是具有任意广播兼容形状的类数组对象的序列。>>> choice_1 = jnp.array([1, 2, 3, 4]) >>> choice_2 = 99 >>> choice_3 = jnp.array([[10], ... [20], ... [30]]) >>> a = jnp.array([[0, 1, 2, 0], ... [1, 2, 0, 1], ... [2, 0, 1, 2]]) >>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap') Array([[ 1, 99, 10, 4], [99, 20, 3, 99], [30, 2, 99, 30]], dtype=int32)