jax.numpy.argsort#
- jax.numpy.argsort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)[源码]#
返回一个数组的排序索引。
JAX 对
numpy.argsort()
的实现。- 参数:
a (Array | ndarray | bool | number | bool | int | float | complex) – 要排序的数组
axis (int | None) – 排序所依据的整数轴。默认为
-1
,即最后一个轴。如果为None
,则a
在排序前会被展平。stable (bool) – 指定是否使用稳定排序的布尔值。默认值=True。
descending (bool) – 指定是否按降序排序的布尔值。默认值=False。
kind (None) – 已弃用;请改为使用 stable=True 或 stable=False 指定排序算法。
order (None) – JAX 不支持。
- 返回:
返回的数组包含排序索引。返回的数组形状将为
a.shape
(如果axis
是整数)或形状为(a.size,)
(如果axis
为 None)。- 返回类型:
示例
简单的一维排序
>>> x = jnp.array([1, 3, 5, 4, 2, 1]) >>> indices = jnp.argsort(x) >>> indices Array([0, 5, 4, 1, 3, 2], dtype=int32) >>> x[indices] Array([1, 1, 2, 3, 4, 5], dtype=int32)
沿数组的最后一个轴进行排序
>>> x = jnp.array([[2, 1, 3], ... [6, 4, 3]]) >>> indices = jnp.argsort(x, axis=1) >>> indices Array([[1, 0, 2], [2, 1, 0]], dtype=int32) >>> jnp.take_along_axis(x, indices, axis=1) Array([[1, 2, 3], [3, 4, 6]], dtype=int32)
另请参阅
jax.numpy.sort()
: 直接返回排序后的值。jax.numpy.lexsort()
: 多个数组的字典序排序。jax.lax.sort()
: 封装 XLA Sort 运算符的底层函数。