jax.nn.logsumexp#
- jax.nn.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) Array[源代码]#
- jax.nn.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) tuple[Array, Array]
- jax.nn.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) Array | tuple[Array, Array]
对数和指数缩减。
JAX 实现
scipy.special.logsumexp()。\[\operatorname{logsumexp} a = \log \sum_i b_i \exp a_i\]其中 \(i\) 索引遍历一个或多个要约减的维度。
- 参数:
a – 输入数组
axis – 整数或整数序列,默认为 None。要计算总和的轴。如果为 None,则沿所有轴计算总和。
b – 指数的缩放因子。必须可广播到 a 的形状。
keepdims – 如果
True,则约减的轴将保留在输出中,大小为 1。return_sign – 如果
True,则输出将是(result, sign)对,其中sign是总和的符号,result包含其绝对值的对数。如果False,则仅返回result,如果总和为负,则其中包含 NaN 值。where – 要包含在约减中的元素。
- 返回:
根据
return_sign参数的值,可以是数组result,也可以是数组对(result, sign)。
另请参阅