jax.nn.initializers 模块#

常用的神经网络层初始化器,与 Keras 和 Sonnet 中使用的定义一致。

初始化器#

本模块提供了常用的神经网络层初始化器,与 Keras 和 Sonnet 中使用的定义一致。

初始化器是一个函数,它接受三个参数: (key, shape, dtype),并返回一个具有 shape 维度和 dtype 数据类型的数组。 参数 key 是一个 PRNG 密钥(例如来自 jax.random.key()),用于生成随机数来初始化数组。

constant(value[, dtype])

构建一个返回全为常量 value 的数组的初始化器。

delta_orthogonal([scale, column_axis, dtype])

构建一个用于 delta 正交核的初始化器。

glorot_normal([in_axis, out_axis, ...])

构建一个 Glorot 正态初始化器(也称为 Xavier 正态初始化器)。

glorot_uniform([in_axis, out_axis, ...])

构建一个 Glorot 均匀初始化器(也称为 Xavier 均匀初始化器)。

he_normal([in_axis, out_axis, batch_axis, dtype])

构建一个 He 正态初始化器(也称为 Kaiming 正态初始化器)。

he_uniform([in_axis, out_axis, batch_axis, ...])

构建一个 He 均匀初始化器(也称为 Kaiming 均匀初始化器)。

lecun_normal([in_axis, out_axis, ...])

构建一个 Lecun 正态初始化器。

lecun_uniform([in_axis, out_axis, ...])

构建一个 Lecun 均匀初始化器。

normal([stddev, dtype])

构建一个返回实数正态分布随机数组的初始化器。

ones(key, shape[, dtype, out_sharding])

返回一个全为 1 的常量数组的初始化器。

orthogonal([scale, column_axis, dtype])

构建一个返回均匀分布的正交矩阵的初始化器。

truncated_normal([stddev, dtype, lower, upper])

构建一个返回截断正态分布随机数组的初始化器。

uniform([scale, dtype])

构建一个返回实数均匀分布随机数组的初始化器。

variance_scaling(scale, mode, distribution)

根据权重张量的形状调整其尺度的初始化器。

zeros(key, shape[, dtype, out_sharding])

返回一个全为 0 的常量数组的初始化器。