jax.example_libraries.stax 模块#

Stax 是一个小型但灵活的神经网络规范库,从零开始构建。

您可能并不想导入此模块!Stax 仅用作示例库。JAX 还有许多其他功能更全面的神经网络库,包括 Google 的 Flax 和 DeepMind 的 Haiku

jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

用于池化层的层构建函数。

jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)[source]#

用于批归一化层的层构建函数。

jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

用于通用卷积层的层构建函数。

jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

用于通用转置卷积层的层构建函数。

jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

用于通用转置卷积层的层构建函数。

jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)[source]#

用于密集(全连接)层的层构造函数。

jax.example_libraries.stax.Dropout(rate, mode='train')[source]#

具有给定比率的 dropout 层的层构建函数。

jax.example_libraries.stax.FanInConcat(axis=-1)[source]#

扇入连接层的层构建函数。

jax.example_libraries.stax.FanOut(num)[source]#

扇出层的层构建函数。

jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#

用于通用卷积层的层构建函数。

jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#

用于通用转置卷积层的层构建函数。

jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

用于池化层的层构建函数。

jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

用于池化层的层构建函数。

jax.example_libraries.stax.elementwise(fun, **fun_kwargs)[source]#

在其输入上逐元素应用标量函数的层。

jax.example_libraries.stax.parallel(*layers)[source]#

用于并行组合层的组合器。

由此组合器产生的层通常与 FanOut 和 FanInSum 层一起使用。

参数:

*layers – 一系列层,每个层都是一个 (init_fun, apply_fun) 对。

返回:

一个新的层,意味着一个 (init_fun, apply_fun) 对,表示给定层序列的并行组合。 特别是,返回的层接受一系列输入,并返回与参数 layers 长度相同的输出序列。

jax.example_libraries.stax.serial(*layers)[source]#

用于串行组合层的组合器。

参数:

*layers – 一系列层,每个层都是一个 (init_fun, apply_fun) 对。

返回:

一个新的层,意味着一个 (init_fun, apply_fun) 对,表示给定层序列的串行组合。

jax.example_libraries.stax.shape_dependent(make_layer)[source]#

将层构造器对延迟到输入形状已知时才进行的组合器。

参数:

make_layer – 一个单参数函数,它接受输入形状作为参数(正整数的元组),并返回一个 (init_fun, apply_fun) 对。

返回:

一个新的层,意味着一个 (init_fun, apply_fun) 对,表示与 make_layer 返回的层相同的层,但其构造被延迟到输入形状已知时。