jax.nn.standardize#
- jax.nn.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[源代码]#
将输入标准化为零均值和单位方差。
标准化计算如下:
\[x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]其中 \(\langle x\rangle\) 表示 \(x\) 的均值,\(\epsilon\) 是一个引入的用于避免除以零的小修正因子。
- 参数:
x (ArrayLike) – 要标准化的输入数组。
axis (Axis) – 用于标准化的轴的整数或整数元组。默认为最后一个轴 (
-1)。mean (ArrayLike | None) – 可选,指定用于标准化的均值。如果未指定,则将使用
x.mean(axis, where=where)。variance (ArrayLike | None) – 可选,指定用于标准化的方差。如果未指定,则将使用
x.var(axis, where=where)。epsilon (ArrayLike) – 添加到方差以避免除以零的修正因子;默认为
1E-5。where (ArrayLike | None) – 可选的布尔掩码,用于在计算均值和方差时指定要使用的元素。
- 返回:
一个与
x形状相同的数组,包含标准化后的输入。- 返回类型: