jax.nn.dot_product_attention#
- jax.nn.dot_product_attention(query, key, value, bias=None, mask=None, *, scale=None, is_causal=False, query_seq_lengths=None, key_value_seq_lengths=None, local_window_size=None, implementation=None)[源代码]#
缩放点积注意力函数。
在 Query、Key 和 Value 张量上计算注意力函数
\[\mathrm{Attention}(Q, K, V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V\]如果我们定义
logits为 \(QK^T\) 的输出,probs为 \(softmax\) 的输出。在此函数中,我们使用以下大写字母表示数组的形状
B = batch size S = length of the key/value (source) T = length of the query (target) N = number of attention heads H = dimensions of each attention head K = number of key/value heads G = number of groups, which equals to N // K
- 参数:
query (ArrayLike) – 查询数组;形状为
(BTNH|TNH)key (ArrayLike) – 键数组:形状为
(BSKH|SKH)。当 K 等于 N 时,执行多头注意力 (MHA https://arxiv.org/abs/1706.03762)。否则,如果 N 是 K 的倍数,则执行分组查询注意力 (GQA https://arxiv.org/abs/2305.13245);如果 K == 1 (GQA 的特例),则执行多查询注意力 (MQA https://arxiv.org/abs/1911.02150)。value (ArrayLike) – 值数组,其形状应与 key 数组相同。
bias (ArrayLike | None) – 可选,要添加到 logits 的偏置数组;形状必须是 4D 并且可以广播到
(BNTS|NTS)。mask (ArrayLike | None) – 可选,用于过滤 logits 的掩码数组。它是一个布尔掩码,其中 True 表示该元素应参与注意力。对于加性掩码,用户应将其传递给 bias。形状必须是 4D 并且可以广播到
(BNTS|NTS)。scale (float | None) – logits 的缩放因子。如果为 None,则缩放因子将设置为 1 除以查询的头维度 (即 H) 的平方根。
is_causal (bool) – 如果为 True,则应用因果注意力。注意,一些实现,如 xla,将生成一个掩码张量并将其应用于 logits 以屏蔽注意力矩阵的非因果部分;而其他实现,如 cudnn,将避免计算非因果区域,从而提供加速。
query_seq_lengths (ArrayLike | None) – 查询的序列长度的 int32 数组;形状为
(B)key_value_seq_lengths (ArrayLike | None) – 键和值的序列长度的 int32 数组;形状为
(B)local_window_size (int | tuple[int, int] | None) – 用于使自注意力关注每个 token 的局部窗口的窗口大小。如果设置,这将指定每个 token 的 (左窗口大小, 右窗口大小)。例如,如果 local_window_size == (3, 2) 且序列为 [0, 1, 2, 3, 4, 5, c, 7, 8, 9],则 token c 可以关注 [3, 4, 5, c, 7, 8]。如果给定单个整数,它将被解释为对称窗口 (window_size, window_size)。
implementation (Literal['xla', 'cudnn'] | None) – 用于控制使用哪个实现后端的字符串。支持的字符串包括 xla、cudnn (cuDNN flash attention)。默认为 None,目前回退到 xla。注意,cudnn 只支持形状/数据类型的子集,如果不支持,将抛出异常。
- 返回:
与
query形状相同的注意力输出数组。- 返回类型: