jax.numpy.searchsorted#

jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')[source]#

在已排序数组中执行二分搜索。

numpy.searchsorted() 的 JAX 实现。

这将返回排序数组 a 中可以插入值 v 以保持其排序顺序的索引。

参数:
  • a (ArrayLike) – 一维数组,假定已排序,除非指定了 sorter

  • v (ArrayLike) – N 维查询值数组

  • side (str) – 'left'(默认)或 'right';指定在出现并列情况下,插入索引将位于左侧还是右侧。

  • sorter (ArrayLike | None) – 可选的索引数组,用于指定 a 的排序顺序。如果指定,则算法假定 a[sorter] 已排序。

  • method (str) – 'scan'(默认)、'scan_unrolled''sort''compare_all' 之一。请参阅下面的“注意”。

返回值:

形状为 v.shape 的插入索引数组。

返回类型:

数组

注意

method 参数控制用于计算插入索引的算法。

  • 'scan'(默认)在 CPU 上往往性能更高,尤其是在 a 非常大时。

  • 'scan_unrolled' 在 GPU 上性能更高,但会增加编译时间。

  • 'sort' 在 GPU 和 TPU 等加速器后端上通常性能更高,尤其是在 v 非常大时。

  • a 非常小时,'compare_all' 往往性能最高。

示例

搜索单个值

>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)

搜索一批值

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)

(可选)可以使用 sorter 参数来查找通过 jax.numpy.argsort() 排序的数组的插入索引

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

结果等同于传递排序后的数组

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)