jax.numpy.average#

jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)[源代码]#

计算加权平均值。

JAX 实现 numpy.average()

参数:
  • a (ArrayLike) – 要计算平均值的数组

  • axis (Axis) – 一个可选的整数或整数序列,指定计算平均值的轴。如果未指定,则沿所有轴计算平均值。

  • weights (ArrayLike | None) – 加权平均的可选权重数组。它必须完全匹配 a 的形状,或者如果指定了 axis,则对于单个轴,它必须具有形状 a.shape[axis],或者对于多个轴,它必须具有形状 tuple(a.shape[ax] for ax in axis)

  • returned (bool) – 如果为 False(默认),则仅返回平均值。如果为 True,则同时返回平均值和归一化因子(即权重的总和)。

  • keepdims (bool) – 如果为 True,则减少的轴在结果中保留大小为 1。如果为 False(默认),则减少的轴会被压缩掉。

返回:

一个数组 average,或者如果 returned 为 True,则是一个数组元组 (average, normalization)

返回类型:

Array | tuple[Array, Array]

另请参阅

示例

简单平均

>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)

加权平均

>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)

使用 returned=True 来选择性地返回归一化因子,即权重的总和

>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))

沿指定轴的加权平均

>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)