jax.scipy.special.logsumexp#

jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) Array[来源]#
jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) tuple[Array, Array]
jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) Array | tuple[Array, Array]

对数和指数缩减。

scipy.special.logsumexp() 的 JAX 实现。

\[\operatorname{logsumexp} a = \log \sum_i b_i \exp a_i\]

其中 \(i\) 索引的范围是一个或多个要减少的维度。

参数:
  • a – 输入数组

  • axis – int 或 int 序列,default=None。要计算总和的轴。 如果为 None,则沿着所有轴计算总和。

  • b – 指数的缩放因子。 必须可广播到 a 的形状。

  • keepdims – 如果 True,则缩减的轴将作为大小为 1 的维度留在输出中。

  • return_sign – 如果 True,输出将是一个 (result, sign) 对,其中 sign 是总和的符号,而 result 包含其绝对值的对数。 如果 False,则仅返回 result,如果总和为负数,它将包含 NaN 值。

  • where – 要包含在缩减中的元素。

返回:

根据 return_sign 参数的值,可以是数组 result 或数组对 (result, sign)