jax.numpy.lexsort#
- jax.numpy.lexsort(keys, axis=-1)[source]#
按字典序对一系列键进行排序。
JAX 对
numpy.lexsort()
的实现。- 参数:
- 返回:
一个形状为
keys[0].shape
的整数数组,表示按字典序排序后的条目索引。- 返回类型:
另请参阅
jax.numpy.argsort()
:按索引排序单个条目。jax.lax.sort()
:直接的 XLA 排序 API。
示例
带有单个键的
lexsort()
等同于argsort()
>>> key1 = jnp.array([4, 2, 3, 2, 5]) >>> jnp.lexsort([key1]) Array([1, 3, 2, 0, 4], dtype=int32) >>> jnp.argsort(key1) Array([1, 3, 2, 0, 4], dtype=int32)
当有多个键时,
lexsort()
使用最后一个键作为主键。>>> key2 = jnp.array([2, 1, 1, 2, 2]) >>> jnp.lexsort([key1, key2]) Array([1, 2, 3, 0, 4], dtype=int32)
打印排序后的键时,索引的含义会更清晰。
>>> indices = jnp.lexsort([key1, key2]) >>> print(f"{key1[indices]}\n{key2[indices]}") [2 3 2 4 5] [1 1 2 2 2]
请注意,
key2
的元素按顺序出现,并且在重复值序列中,key1
的相应元素也按顺序出现。对于多维输入,
lexsort()
默认沿最后一个轴排序。>>> key1 = jnp.array([[2, 4, 2, 3], ... [3, 1, 2, 2]]) >>> key2 = jnp.array([[1, 2, 1, 3], ... [2, 1, 2, 1]]) >>> jnp.lexsort([key1, key2]) Array([[0, 2, 1, 3], [1, 3, 2, 0]], dtype=int32)
可以使用
axis
关键字选择不同的排序轴;这里我们沿最前面的轴排序。>>> jnp.lexsort([key1, key2], axis=0) Array([[0, 1, 0, 1], [1, 0, 1, 0]], dtype=int32)