jax.scipy.special.softmax#

jax.scipy.special.softmax(x, /, *, axis=None)[源代码]#

Softmax 函数。

JAX 实现 scipy.special.softmax()

计算该函数,该函数将元素重新缩放到范围 \([0, 1]\),使得 axis 上的元素总和为 \(1\)

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
参数:
  • x (ArrayLike) – 输入数组

  • axis (int | tuple[int, ...] | None) – 计算 softmax 的轴或轴。沿这些维度求和的 softmax 输出应为 \(1\)

返回:

x 的形状相同的数组。

返回类型:

Array

注意

如果任何输入值为 +inf,则结果将全部为 NaN:这反映了在浮点数数学的上下文中 inf / inf 是未定义的。

另请参阅

log_softmax()