jax.numpy.linalg.tensorsolve#

jax.numpy.linalg.tensorsolve(a, b, axes=None)[源代码]#

求解张量方程 a x = b 中的 x。

JAX 实现 numpy.linalg.tensorsolve()

参数:
  • a (ArrayLike) – 输入数组。在根据 axes 重排(如下所述)后,形状必须为 (*b.shape, *x.shape)

  • b (ArrayLike) – 右侧数组。

  • axes (tuple[int, ...] | None) – 可选元组,用于指定 a 中应移到末尾的轴

返回:

数组 x,使得在重排 a 的轴后,tensordot(a, x, x.ndim) 等价于 b

返回类型:

Array

示例

>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)

现在展示 x 可以使用 tensordot() 来重建 b

>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)