jax.nn.initializers.variance_scaling#
- jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'numpy.float64'>)[源代码]#
初始化器,根据权重张量的形状调整其尺度。
使用
distribution="truncated_normal"
或distribution="normal"
,样本从均值为零且标准差(如果适用,截断后)为 \(\sqrt{\frac{scale}{n}}\) 的(截断)正态分布中抽取,其中 n 对于每个mode
是:"fan_in"
: 输入的数量"fan_out"
: 输出的数量"fan_avg"
: 输入和输出数量的算术平均值"fan_geo_avg"
: 输入和输出数量的几何平均值
此初始化器可以使用
in_axis
、out_axis
和batch_axis
进行配置,以用于通用卷积层或密集层;未包含在这些参数中的轴假定为“感受野”(卷积核空间轴)。使用
distribution="truncated_normal"
,样本的绝对值在缩放之前被截断为 2 个标准差。使用
distribution="uniform"
,样本从以下位置抽取:如果 dtype 是实数,则为均匀区间,或者
如果 dtype 是复数,则为均匀圆盘,
均值为零且标准差为 \(\sqrt{\frac{scale}{n}}\),其中 n 定义如上。
- 参数:
scale (RealNumeric) – 缩放因子(正浮点数)。
mode (Literal['fan_in'] | Literal['fan_out'] | Literal['fan_avg'] | Literal['fan_geo_avg']) –
"fan_in"
、"fan_out"
、"fan_avg"
和"fan_geo_avg"
之一。distribution (Literal['truncated_normal'] | Literal['normal'] | Literal['uniform']) – 要使用的随机分布。
"truncated_normal"
、"normal"
和"uniform"
之一。dtype (DTypeLikeInexact) – 权重的 dtype。
- 返回类型:
初始化器