jax.numpy.partition#

jax.numpy.partition(a, kth, axis=-1)[源代码]#

返回数组的部分排序副本。

numpy.partition() 的 JAX 实现。 JAX 版本与 NumPy 在 NaN 条目的处理上有所不同:负位已设置的 NaN 将被排序到数组的开头。

参数:
返回:

沿 axis 在第 kth 个值处分区的 a 的副本。 kth 之前的条目是小于 take(a, kth, axis) 的值,kth 之后的条目是大于 take(a, kth, axis) 的值的索引

返回类型:

Array

注意

JAX 版本要求 kth 参数是静态整数,而不是通用数组。 这是通过两次调用 jax.lax.top_k() 实现的。 如果您只访问输出的前 k 个或后 k 个值,则直接调用 jax.lax.top_k() 可能会更有效。

另请参阅

示例

>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)

结果是输入的部分排序副本。 kth 之前的所有值都小于轴心值,kth 之后的所有值都大于轴心值

>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]

请注意,在 smallest_valueslargest_values 中,返回的顺序是任意的且依赖于实现的。