jax.numpy.histogram#
- jax.numpy.histogram(a, bins=10, range=None, weights=None, density=None)[源代码]#
计算一维直方图。
JAX 实现
numpy.histogram()。- 参数:
a (ArrayLike) – 需要分箱的值数组。可以是任何大小或维度。
bins (ArrayLike) – 指定直方图中 bin 的数量(默认值:10)。
bins也可以是指定 bin 边界的数组。range (Sequence[ArrayLike] | None) – 标量元组。指定数据的范围。如果未指定,则从数据中推断范围。
weights (ArrayLike | None) – 指定数据点权重的可选数组。应与
a兼容并可广播。如果未指定,则每个数据点加权相等。density (bool | None) – 如果为 True,则返回单位长度计数的归一化直方图。如果为 False(默认值),则返回每个 bin 的(加权)计数。
- 返回:
一个元组,包含数组
(histogram, bin_edges),其中histogram包含聚合数据,bin_edges指定 bin 的边界。- 返回类型:
另请参阅
jax.numpy.bincount(): 计算数组中每个值出现的次数。jax.numpy.histogram2d(): 计算二维数组的直方图。jax.numpy.histogramdd(): 计算 N 维数组的直方图。jax.numpy.histogram_bin_edges():计算直方图的 bin 边缘。
示例
>>> a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25]) >>> counts, bin_edges = jnp.histogram(a, bins=8) >>> print(counts) [3. 0. 0. 2. 1. 0. 1. 1.] >>> print(bin_edges) [ 1. 4. 7. 10. 13. 16. 19. 22. 25.]
指定 bin 范围
>>> counts, bin_edges = jnp.histogram(a, range=(0, 25), bins=5) >>> print(counts) [3. 0. 2. 2. 1.] >>> print(bin_edges) [ 0. 5. 10. 15. 20. 25.]
显式指定 bin 边界
>>> bin_edges = jnp.array([0, 10, 20, 30]) >>> counts, _ = jnp.histogram(a, bins=bin_edges) >>> print(counts) [3. 4. 1.]
使用
density=True返回一个归一化的直方图>>> density, bin_edges = jnp.histogram(a, density=True) >>> dx = jnp.diff(bin_edges) >>> normed_sum = jnp.sum(density * dx) >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool)