jax.numpy.matvec#
- jax.numpy.matvec(x1, x2, /)[源码]#
批量矩阵-向量乘积。
JAX 对
numpy.matvec()的实现。- 参数:
- 返回:
形状为
(..., M)的数组,包含批量矩阵-向量乘积。- 返回类型:
另请参阅
jax.numpy.linalg.vecdot():批处理向量积。jax.numpy.vecmat(): 向量-矩阵乘积。jax.numpy.matmul():一般矩阵乘法。
示例
简单的矩阵-向量乘积
>>> x1 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> x2 = jnp.array([7, 8, 9]) >>> jnp.matvec(x1, x2) Array([ 50, 122], dtype=int32)
批量矩阵-向量乘积
>>> x2 = jnp.array([[7, 8, 9], ... [5, 6, 7]]) >>> jnp.matvec(x1, x2) Array([[ 50, 122], [ 38, 92]], dtype=int32)