jax.lax.dot_general#
- jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, *, out_sharding=None)[源代码]#
通用点积/缩并运算符。
包装了 XLA 的 DotGeneral 运算符。
dot_general
的语义很复杂,但大多数用户不应直接使用它。相反,您可以使用更高级别的函数,如jax.numpy.dot()
,jax.numpy.matmul()
,jax.numpy.tensordot()
,jax.numpy.einsum()
等,它们将在后台构建对dot_general
的适当调用。如果您真的想了解dot_general
本身,我们建议阅读 XLA 的 DotGeneral 运算符文档。- 参数:
lhs (ArrayLike) – 一个数组
rhs (ArrayLike) – 一个数组
dimension_numbers (DotDimensionNumbers) – 一个元组的元组,其中包含整数序列,形式为
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
precision (PrecisionLike | None) –
可选。此参数控制计算的数值精度,可以是以下之一
None
,表示当前后端的默认精度,一个
DotAlgorithm
或一个DotAlgorithmPreset
,指示必须用于累积点积的算法。
preferred_element_type (DTypeLike | None | None) – 可选。此参数控制点积输出的数据类型。默认情况下,此操作的输出元素类型将匹配
lhs
和rhs
输入元素类型,并遵循通常的类型提升规则。将preferred_element_type
设置为特定的dtype
将意味着该操作返回该元素类型。当precision
不是DotAlgorithm
或DotAlgorithmPreset
时,preferred_element_type
为编译器提供了一个提示,以使用此数据类型累积点积。
- 返回:
一个数组,其第一个维度是(共享的)批处理维度,其次是
lhs
非缩并/非批处理维度,最后是rhs
非缩并/非批处理维度。- 返回类型: