jax.scipy.linalg.expm_frechet#
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) tuple[Array, Array][源代码]#
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) Array
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) Array | tuple[Array, Array]
计算矩阵指数的弗雷歇导数。
JAX 对
scipy.linalg.expm_frechet()的实现。- 参数:
A – 形状为
(..., N, N)的数组。E – 形状为
(..., N, N)的数组;指定导数的方向。compute_expm – 如果为 True(默认值),则计算并返回
expm(A)。method – JAX 忽略此参数。
- 返回:
如果
compute_expm为 True,则返回一个元组(expm_A, expm_frechet_AE),否则返回数组expm_frechet_AE。两个返回的数组的形状都是(..., N, N)。
示例
我们可以使用此 API 计算
A的矩阵指数,以及它在E方向上的导数。>>> key1, key2 = jax.random.split(jax.random.key(3372)) >>> A = jax.random.normal(key1, (3, 3)) >>> E = jax.random.normal(key2, (3, 3)) >>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)
这也可以通过 JAX 的自动微分方法等效地计算;在这里,我们将使用
jax.jvp()计算expm()在E方向上的导数,并得到相同的结果。>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,)) >>> jnp.allclose(expmA, expmA2) Array(True, dtype=bool) >>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2) Array(True, dtype=bool)