jax.scipy.linalg.sqrtm#
- jax.scipy.linalg.sqrtm(A, blocksize=1)[源代码]#
计算矩阵平方根
此函数使用
scipy.linalg.schur()
实现,该函数仅在 CPU 上受支持。scipy.linalg.sqrtm()
的 JAX 实现。- 参数:
A (ArrayLike) – 形状为
(N, N)
的数组blocksize (int) – JAX 中不支持;JAX 始终使用
blocksize=1
。
- 返回:
包含
A
的矩阵平方根的形状为(N, N)
的数组- 返回类型:
示例
>>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> sqrt_a = jax.scipy.linalg.sqrtm(a) >>> with jnp.printoptions(precision=2, suppress=True): ... print(sqrt_a) [[0.92+0.71j 0.54+0.j 0.92-0.71j] [0.54+0.j 1.85+0.j 0.54-0.j ] [0.92-0.71j 0.54-0.j 0.92+0.71j]]
根据定义,矩阵平方根与其自身的矩阵乘法应等于输入
>>> jnp.allclose(a, sqrt_a @ sqrt_a) Array(True, dtype=bool)
注意事项
此函数实现了 [1] 中描述的复数 Schur 方法。 由于 JAX 中尚无 Sylvester 方程求解器,因此它不使用递归分块来加速计算。
参考文献