jax.lax.top_k#

jax.lax.top_k(operand, k)[source]#

返回 operand 沿最后一个轴的 top k 个值及其索引。

参数:
  • operand (ArrayLike) – 非复数类型的 N 维数组。

  • k (int) – 整数,指定顶部条目的数量。

返回:

一个元组 (values, indices),其中

  • values 是一个数组,包含沿最后一个轴的 top k 个值。

  • indices 是一个数组,包含与值对应的索引。

返回类型:

tuple[Array, Array]

示例

查找数组中最大的三个值及其索引

>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10.,  9.,  6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)