jax.numpy.vecmat#

jax.numpy.vecmat(x1, x2, /)[source]#

批量共轭向量-矩阵乘积。

numpy.vecmat() 的 JAX 实现。

参数:
返回:

形状为 (..., N) 的数组,包含批量共轭向量-矩阵乘积。

返回类型:

Array

另请参阅

示例

简单向量-矩阵乘积

>>> x1 = jnp.array([[1, 2, 3]])
>>> x2 = jnp.array([[4, 5],
...                 [6, 7],
...                 [8, 9]])
>>> jnp.vecmat(x1, x2)
Array([[40, 46]], dtype=int32)

批量向量-矩阵乘积

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.vecmat(x1, x2)
Array([[ 40,  46],
       [ 94, 109]], dtype=int32)