jax.nn.standardize#

jax.nn.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[source]#

将输入标准化为零均值和单位方差。

标准化由下式给出:

\[x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]

其中 \(\langle x\rangle\) 表示 \(x\) 的均值,\(\epsilon\) 是引入的一个小校正因子,以避免除以零。

参数:
  • x (ArrayLike) – 要标准化的输入数组。

  • axis (Axis) – 整数或整数元组,表示要进行标准化的轴。默认为最后一个轴(-1)。

  • mean (ArrayLike | None) – 可选地指定用于标准化的均值。如果未指定,则将使用 x.mean(axis, where=where)

  • variance (ArrayLike | None) – 可选地指定用于标准化的方差。如果未指定,则将使用 x.var(axis, where=where)

  • epsilon (ArrayLike) – 添加到方差的校正因子,以避免除以零;默认为 1E-5

  • where (ArrayLike | None) – 可选的布尔掩码,指定计算均值和方差时要使用的元素。

返回:

x 形状相同的数组,包含标准化后的输入。

返回类型:

Array