jax.numpy.average#
- jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)[源代码]#
计算加权平均值。
JAX 实现
numpy.average()。- 参数:
a (ArrayLike) – 要计算平均值的数组
axis (Axis) – 一个可选的整数或整数序列,指定计算平均值的轴。如果未指定,则沿所有轴计算平均值。
weights (ArrayLike | None) – 加权平均的可选权重数组。它必须完全匹配 a 的形状,或者如果指定了 axis,则对于单个轴,它必须具有形状
a.shape[axis],或者对于多个轴,它必须具有形状tuple(a.shape[ax] for ax in axis)。returned (bool) – 如果为 False(默认),则仅返回平均值。如果为 True,则同时返回平均值和归一化因子(即权重的总和)。
keepdims (bool) – 如果为 True,则减少的轴在结果中保留大小为 1。如果为 False(默认),则减少的轴会被压缩掉。
- 返回:
一个数组
average,或者如果returned为 True,则是一个数组元组(average, normalization)。- 返回类型:
另请参阅
jax.numpy.mean():无权平均值。
示例
简单平均
>>> x = jnp.array([1, 2, 3, 2, 4]) >>> jnp.average(x) Array(2.4, dtype=float32)
加权平均
>>> weights = jnp.array([2, 1, 3, 2, 2]) >>> jnp.average(x, weights=weights) Array(2.5, dtype=float32)
使用
returned=True来选择性地返回归一化因子,即权重的总和>>> jnp.average(x, returned=True) (Array(2.4, dtype=float32), Array(5., dtype=float32)) >>> jnp.average(x, weights=weights, returned=True) (Array(2.5, dtype=float32), Array(10., dtype=float32))
沿指定轴的加权平均
>>> x = jnp.array([[8, 2, 7], ... [3, 6, 4]]) >>> weights = jnp.array([1, 2, 3]) >>> jnp.average(x, weights=weights, axis=1) Array([5.5, 4.5], dtype=float32)