jax.nn.standardize#
- jax.nn.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[source]#
将输入标准化为零均值和单位方差。
标准化由下式给出:
\[x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]其中 \(\langle x\rangle\) 表示 \(x\) 的均值,\(\epsilon\) 是引入的一个小校正因子,以避免除以零。
- 参数:
x (ArrayLike) – 要标准化的输入数组。
axis (Axis) – 整数或整数元组,表示要进行标准化的轴。默认为最后一个轴(
-1
)。mean (ArrayLike | None) – 可选地指定用于标准化的均值。如果未指定,则将使用
x.mean(axis, where=where)
。variance (ArrayLike | None) – 可选地指定用于标准化的方差。如果未指定,则将使用
x.var(axis, where=where)
。epsilon (ArrayLike) – 添加到方差的校正因子,以避免除以零;默认为
1E-5
。where (ArrayLike | None) – 可选的布尔掩码,指定计算均值和方差时要使用的元素。
- 返回:
与
x
形状相同的数组,包含标准化后的输入。- 返回类型: