jax.numpy.vdot#
- jax.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)[源代码]#
执行两个 1D 向量的共轭乘法。
JAX 实现的
numpy.vdot()
。- 参数:
a (Array | ndarray | bool | number | bool | int | float | complex) – 第一个输入数组,如果不是 1D,它将被展平。
b (Array | ndarray | bool | number | bool | int | float | complex) – 第二个输入数组,如果不是 1D,它将被展平。必须有
a.size == b.size
。precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – 要么是
None
(默认值),这意味着后端的默认精度,要么是一个Precision
枚举值 (Precision.DEFAULT
,Precision.HIGH
或Precision.HIGHEST
),要么是一个由两个此类值组成的元组,指示a
和b
的精度。preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – 要么是
None
(默认值),这意味着输入类型的默认累积类型,要么是数据类型,指示累积结果并返回具有该数据类型的结果。
- 返回:
包含输入共轭向量积的标量数组(形状为
()
)。- 返回类型:
另请参阅
jax.numpy.vecdot()
:批量向量积。jax.numpy.matmul()
:一般矩阵乘法。jax.lax.dot_general()
:一般的 N 维批量点积。
示例
>>> x = jnp.array([1j, 2j, 3j]) >>> y = jnp.array([1., 2., 3.]) >>> jnp.vdot(x, y) Array(0.-14.j, dtype=complex64)
请注意此函数与
dot()
之间的区别,后者在复杂的情况下不会对第一个输入进行共轭处理>>> jnp.dot(x, y) Array(0.+14.j, dtype=complex64)