jax.nn.initializers.variance_scaling#

jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=None)[源代码]#

根据权重张量的形状调整其尺度的初始化器。

distribution="truncated_normal"distribution="normal" 时,从中(截断)正态分布中抽取样本,均值为零,标准差(截断后,如果适用)为 \(\sqrt{\frac{scale}{n}}\),其中 n 对于每种 mode

  • "fan_in":输入的数量

  • "fan_out":输出的数量

  • "fan_avg":输入和输出数量的算术平均值

  • "fan_geo_avg":输入和输出数量的几何平均值

此初始化器可以使用 in_axisout_axisbatch_axis 进行配置,以处理通用的卷积层或全连接层;不属于任何这些参数的轴被假定为“感受野”(卷积核的空间轴)。

distribution="truncated_normal" 时,样本的绝对值在缩放前会被截断为 2 个标准差。

distribution="uniform" 时,从中抽取样本

  • 如果 dtype 是实数,则从一个均匀区间中抽取;或者

  • 如果 dtype 是复数,则从一个均匀圆盘中抽取;

均值为零,标准差为 \(\sqrt{\frac{scale}{n}}\),其中 n 定义如上。

参数:
  • scale (RealNumeric) – 缩放因子(正浮点数)。

  • mode (Literal['fan_in'] | Literal['fan_out'] | Literal['fan_avg'] | Literal['fan_geo_avg']) – "fan_in""fan_out""fan_avg""fan_geo_avg" 中的一个。

  • distribution (Literal['truncated_normal'] | Literal['normal'] | Literal['uniform']) – 要使用的随机分布。在 "truncated_normal""normal""uniform" 中选择一个。

  • in_axis (int | Sequence[int]) – 权重数组中输入维度的轴或轴序列。

  • out_axis (int | Sequence[int]) – 权重数组中输出维度的轴或轴序列。

  • batch_axis (int | Sequence[int]) – 权重数组中应被忽略的轴或轴序列。

  • dtype (DTypeLikeInexact | None) – 权重的 dtype。

返回类型:

初始化器