jax.numpy.linalg.tensorsolve#
- jax.numpy.linalg.tensorsolve(a, b, axes=None)[源代码]#
求解张量方程 a x = b 中的 x。
JAX 实现
numpy.linalg.tensorsolve()。- 参数:
- 返回:
数组 x,使得在重排
a的轴后,tensordot(a, x, x.ndim)等价于b。- 返回类型:
示例
>>> key1, key2 = jax.random.split(jax.random.key(8675309)) >>> a = jax.random.normal(key1, shape=(2, 2, 4)) >>> b = jax.random.normal(key2, shape=(2, 2)) >>> x = jnp.linalg.tensorsolve(a, b) >>> x.shape (4,)
现在展示
x可以使用tensordot()来重建b。>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim) >>> jnp.allclose(b, b_reconstructed) Array(True, dtype=bool)