jax.nn.initializers.variance_scaling#
- jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=None)[源代码]#
根据权重张量的形状调整其尺度的初始化器。
当
distribution="truncated_normal"或distribution="normal"时,从中(截断)正态分布中抽取样本,均值为零,标准差(截断后,如果适用)为 \(\sqrt{\frac{scale}{n}}\),其中 n 对于每种mode"fan_in":输入的数量"fan_out":输出的数量"fan_avg":输入和输出数量的算术平均值"fan_geo_avg":输入和输出数量的几何平均值
此初始化器可以使用
in_axis、out_axis和batch_axis进行配置,以处理通用的卷积层或全连接层;不属于任何这些参数的轴被假定为“感受野”(卷积核的空间轴)。当
distribution="truncated_normal"时,样本的绝对值在缩放前会被截断为 2 个标准差。当
distribution="uniform"时,从中抽取样本如果 dtype 是实数,则从一个均匀区间中抽取;或者
如果 dtype 是复数,则从一个均匀圆盘中抽取;
均值为零,标准差为 \(\sqrt{\frac{scale}{n}}\),其中 n 定义如上。
- 参数:
scale (RealNumeric) – 缩放因子(正浮点数)。
mode (Literal['fan_in'] | Literal['fan_out'] | Literal['fan_avg'] | Literal['fan_geo_avg']) –
"fan_in"、"fan_out"、"fan_avg"和"fan_geo_avg"中的一个。distribution (Literal['truncated_normal'] | Literal['normal'] | Literal['uniform']) – 要使用的随机分布。在
"truncated_normal"、"normal"和"uniform"中选择一个。dtype (DTypeLikeInexact | None) – 权重的 dtype。
- 返回类型:
初始化器