jax.numpy.digitize#

jax.numpy.digitize(x, bins, right=False, *, method=None)[源代码]#

将数组转换为 bin 索引。

JAX 实现的 numpy.digitize()

参数:
  • x (ArrayLike) – 要进行数字化的值数组。

  • bins (ArrayLike) – 一维 bin 边缘数组。必须是单调递增或递减的。

  • right (bool) – 如果为 true,则区间包括右侧 bin 边缘。如果为 false(默认),则区间包括左侧 bin 边缘。

  • method (str | None) – 要传递给 searchsorted() 的可选 method 参数。 有关可用选项,请参阅该函数。

返回:

一个与 x 形状相同的整数数组,指示值所在的 bin 编号。

返回类型:

Array

参见

示例

>>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5])
>>> bins = jnp.array([1, 2, 3])
>>> jnp.digitize(x, bins)
Array([1, 2, 2, 1, 3, 3], dtype=int32)
>>> jnp.digitize(x, bins, right=True)
Array([0, 1, 2, 1, 2, 3], dtype=int32)

digitize 也支持反向排序的 bin

>>> bins = jnp.array([3, 2, 1])
>>> jnp.digitize(x, bins)
Array([2, 1, 1, 2, 0, 0], dtype=int32)