jax.nn 模块#

神经网络库的常用函数。

激活函数#

relu

ReLU(Rectified linear unit)激活函数。

relu6

ReLU 6 激活函数。

sigmoid(x)

Sigmoid 激活函数。

softplus(x)

Softplus 激活函数。

sparse_plus(x)

Sparse plus 函数。

sparse_sigmoid(x)

Sparse sigmoid 激活函数。

soft_sign(x)

Soft-sign 激活函数。

silu(x)

SiLU(又称 swish)激活函数。

swish(x)

SiLU(又称 swish)激活函数。

log_sigmoid(x)

Log-sigmoid 激活函数。

leaky_relu(x[, negative_slope])

Leaky ReLU(Leaky rectified linear unit)激活函数。

hard_sigmoid(x)

Hard Sigmoid 激活函数。

hard_silu(x)

硬 SiLU (swish) 激活函数

hard_swish(x)

硬 SiLU (swish) 激活函数

hard_tanh(x)

Hard \(\mathrm{tanh}\) 激活函数。

elu(x[, alpha])

ELU(Exponential linear unit)激活函数。

celu(x[, alpha])

连续可微的指数线性单元激活函数。

selu(x)

SELU(Scaled exponential linear unit)激活函数。

gelu(x[, approximate])

GELU(Gaussian error linear unit)激活函数。

glu(x[, axis])

GLU(Gated linear unit)激活函数。

squareplus(x[, b])

Squareplus 激活函数。

mish(x)

Mish 激活函数。

identity(x)

Identity 激活函数。

其他函数#

softmax(x[, axis, where])

Softmax 函数。

log_softmax(x[, axis, where])

Log-Softmax 函数。

logmeanexp(x[, axis, where, keepdims])

Log mean exp。

logsumexp()

对数和指数缩减。

standardize(x[, axis, mean, variance, ...])

将输入标准化为零均值和单位方差。

one_hot(x, num_classes, *[, dtype, axis])

对给定的索引进行独热编码。

dot_product_attention(query, key, value[, ...])

点积注意力函数。

scaled_matmul(lhs, rhs, lhs_scales, rhs_scales)

缩放矩阵乘法函数。

get_scaled_dot_general_config(mode[, ...])

获取 scaled_dot_general 的量化配置。

scaled_dot_general(lhs, rhs, dimension_numbers)

缩放点积通用操作。

log1mexp

\(\log(1 - \exp(-x))\) 进行数值稳定的计算。