jax.numpy.argpartition#
- jax.numpy.argpartition(a, kth, axis=-1)[源代码]#
返回对数组进行部分排序的索引。
JAX 实现
numpy.argpartition()。JAX 版本与 NumPy 在处理 NaN 条目时有所不同:设置了负位数的 NaN 会被排序到数组的开头。- 参数:
- 返回:
在
axis轴上,在kth值处划分a的索引。kth前面的条目是小于take(a, kth, axis)的值的索引,而kth后面的条目是大于take(a, kth, axis)的值的索引。- 返回类型:
注意
JAX 版本要求
kth参数是静态整数,而不是通用数组。这是通过两次调用jax.lax.top_k()来实现的。如果您只访问输出的顶部或底部 k 个值,直接调用jax.lax.top_k()可能会更有效。另请参阅
jax.numpy.partition():直接部分排序jax.numpy.argsort():完整间接排序jax.lax.top_k():直接查找前 k 个条目jax.lax.approx_max_k():计算近似前 k 个条目jax.lax.approx_min_k():计算近似后 k 个条目
示例
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) >>> kth = 4 >>> idx = jnp.argpartition(x, kth) >>> idx Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
结果是一个对输入进行部分排序的索引序列。
kth之前的所有索引都对应小于枢轴值的值,而kth之后的所有索引都对应大于枢轴值的值。>>> x_partitioned = x[idx] >>> smallest_values = x_partitioned[:kth] >>> pivot_value = x_partitioned[kth] >>> largest_values = x_partitioned[kth + 1:] >>> print(smallest_values, pivot_value, largest_values) [1 2 3 3] 4 [6 8 9 7 5]
请注意,在
smallest_values和largest_values之间,返回的顺序是任意的,并且取决于实现。