jax.numpy.average#

jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)[source]#

计算加权平均值。

JAX 实现的 numpy.average()

参数:
  • a (ArrayLike) – 待平均的数组

  • axis (Axis) – 一个可选的整数或整数序列,指定计算平均值的轴。如果未指定,则沿所有轴计算平均值。

  • weights (ArrayLike | None) – 一个可选的加权平均权重数组。必须与 a 广播兼容。

  • returned (bool) – 如果为 False(默认值),则只返回平均值。如果为 True,则同时返回平均值和归一化因子(即权重的总和)。

  • keepdims (bool) – 如果为 True,则结果中保留被缩减的轴,大小为 1。如果为 False(默认值),则缩减的轴会被去除。

返回:

一个数组 average 或数组元组 (average, normalization) 如果 returned 为 True。

返回类型:

Array | tuple[Array, Array]

另请参阅

示例

简单平均值

>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)

加权平均值

>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)

使用 returned=True 可选择返回归一化因子,即权重的总和

>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))

沿指定轴的加权平均值

>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)