jax.numpy.partition#
- jax.numpy.partition(a, kth, axis=-1)[源代码]#
返回数组的部分排序副本。
JAX 实现
numpy.partition()。JAX 版本与 NumPy 在处理 NaN 条目方面有所不同:设置了负位数的 NaN 会被排序到数组的开头。- 参数:
- 返回:
沿
axis在kth值处分区的a的副本。kth之前的值小于take(a, kth, axis),而kth之后的值的索引大于take(a, kth, axis)- 返回类型:
注意
JAX 版本要求
kth参数是一个静态整数,而不是一个通用数组。这是通过两次调用jax.lax.top_k()来实现的。如果您只访问输出的最高或最低 k 个值,直接调用jax.lax.top_k()可能会更有效。另请参阅
jax.numpy.sort():完全排序jax.numpy.argpartition():间接部分排序jax.lax.top_k():直接查找前 k 个条目jax.lax.approx_max_k():计算近似前 k 个条目jax.lax.approx_min_k():计算近似后 k 个条目
示例
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3]) >>> kth = 4 >>> x_partitioned = jnp.partition(x, kth) >>> x_partitioned Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
结果是输入的部分排序副本。
kth之前的所有值都小于枢轴值,而kth之后的所有值都大于枢轴值>>> smallest_values = x_partitioned[:kth] >>> pivot_value = x_partitioned[kth] >>> largest_values = x_partitioned[kth + 1:] >>> print(smallest_values, pivot_value, largest_values) [1 2 3 3] 4 [9 8 7 6 5]
请注意,在
smallest_values和largest_values之间,返回的顺序是任意的,并且取决于实现。