jax.numpy.vecdot#

jax.numpy.vecdot(x1, x2, /, *, axis=-1, precision=None, preferred_element_type=None)[source]#

执行两个批处理向量的共轭乘法。

numpy.vecdot() 的 JAX 实现。

参数:
返回:

数组,包含 ab 沿 axis 的共轭点积。非收缩维度会被广播在一起。

返回类型:

Array

参见

示例

两个 1D 数组的向量共轭点积

>>> a = jnp.array([1j, 2j, 3j])
>>> b = jnp.array([4., 5., 6.])
>>> jnp.linalg.vecdot(a, b)
Array(0.-32.j, dtype=complex64)

两个 2D 数组的批处理向量点积

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[2, 3, 4]])
>>> jnp.linalg.vecdot(a, b, axis=-1)
Array([20, 47], dtype=int32)