jax.numpy.histogram2d#
- jax.numpy.histogram2d(x, y, bins=10, range=None, weights=None, density=None)[源代码]#
计算二维直方图。
JAX 实现
numpy.histogram2d()。- 参数:
x (ArrayLike) – 用于装箱的点的 x 值的一维数组。
y (ArrayLike) – 用于装箱的点的 y 值的一维数组。
bins (ArrayLike | list[ArrayLike]) – 指定直方图的 bin 数量(默认值:10)。
bins也可以是指定 bin 边缘位置的数组,或者是一对方整数或一对数组,用于指定每个维度中的 bin 数量。range (Sequence[None | Array | Sequence[ArrayLike]] | None) – 形式为
[[xmin, xmax], [ymin, ymax]]的数组或列表对,用于指定每个维度中的数据范围。如果未指定,则范围从数据中推断。weights (ArrayLike | None) – 指定数据点权重的可选数组。应与
x和y具有相同的形状。如果未指定,则每个数据点都具有相等的权重。density (bool | None) – 如果为 True,则返回单位面积计数的归一化直方图。如果为 False(默认),则返回每个 bin 的(加权)计数。
- 返回:
一个元组
(histogram, x_edges, y_edges),其中histogram包含聚合数据,而x_edges和y_edges指定 bin 的边界。- 返回类型:
另请参阅
jax.numpy.histogram():计算一维数组的直方图。jax.numpy.histogramdd():计算 N 维数组的直方图。jax.numpy.histogram_bin_edges():计算直方图的 bin 边缘。
示例
>>> x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25]) >>> y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18]) >>> counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8) >>> counts.shape (8, 8) >>> x_edges Array([ 1., 4., 7., 10., 13., 16., 19., 22., 25.], dtype=float32) >>> y_edges Array([ 2., 4., 6., 8., 10., 12., 14., 16., 18.], dtype=float32)
指定 bin 范围
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, range=[(0, 25), (0, 25)], bins=5) >>> counts.shape (5, 5) >>> x_edges Array([ 0., 5., 10., 15., 20., 25.], dtype=float32) >>> y_edges Array([ 0., 5., 10., 15., 20., 25.], dtype=float32)
明确指定 bin 边缘
>>> x_edges = jnp.array([0, 10, 20, 30]) >>> y_edges = jnp.array([0, 10, 20, 30]) >>> counts, _, _ = jnp.histogram2d(x, y, bins=[x_edges, y_edges]) >>> counts Array([[3, 0, 0], [1, 3, 0], [0, 1, 0]], dtype=int32)
使用
density=True返回一个归一化的直方图>>> density, x_edges, y_edges = jnp.histogram2d(x, y, density=True) >>> dx = jnp.diff(x_edges) >>> dy = jnp.diff(y_edges) >>> normed_sum = jnp.sum(density * dx[:, None] * dy[None, :]) >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool)