jax.numpy.matvec#

jax.numpy.matvec(x1, x2, /)[源代码]#

批量矩阵-向量积。

numpy.matvec() 的 JAX 实现。

参数:
返回:

形状为 (..., M) 的数组,包含批量矩阵-向量积。

返回类型:

Array

参见

示例

简单矩阵-向量积

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([7, 8, 9])
>>> jnp.matvec(x1, x2)
Array([ 50, 122], dtype=int32)

批量矩阵-向量积

>>> x2 = jnp.array([[7, 8, 9],
...                 [5, 6, 7]])
>>> jnp.matvec(x1, x2)
Array([[ 50, 122],
       [ 38,  92]], dtype=int32)