公式很短,动机却很值得想清楚

先问一个具体的问题

给你一句话切成 5 个 token:[I, like, the, blue, sky]。现在你想让模型处理第 4 个 token blue,让它的内部表示带上整句话的上下文信息。

最直白的做法是把所有 5 个 token 的向量平均一下:

context = (h1 + h2 + h3 + h4 + h5) / 5

这能不能用?能用。但它有一个明显问题:blue 来说,sky 应该比 I 更重要——它们在语义上更紧密。把它们等权平均,等于丢失了"相关度"这个信号。

那把权重做成可学习的呢?

# 给每个位置训练一个固定权重
w = nn.Parameter(torch.randn(5))
context = (w[0]*h1 + w[1]*h2 + ... ) / w.sum()

也不行。这套权重一旦训练完就固定了。换一个句子,比如 [I, like, the, red, sky],模型用的还是同一套权重——但显然 redblue 在不同语境里"应该看哪个 token"是不同的。

我们真正需要的是:权重要根据当前 token 是什么、其他 token 是什么,动态决定

这就是 attention 的核心动机。

把动机翻译成公式

让每个 token 自己说三件事:

  • "我想找什么样的信息"(Query
  • "我能提供什么样的信息"(Key
  • "如果你来找我,我会给你这些内容"(Value

对位置 i 的 token:

  1. 用它的 Query 和所有其他 token 的 Key 算"匹配分数"——score[i, j] = Q_i · K_j
  2. 把这些分数过 softmax,得到一组和为 1 的权重 α[i, :]
  3. 用这组权重对所有 Value 加权求和——out_i = Σ α[i, j] * V_j

写成矩阵:

Attention(Q, K, V) = softmax(Q @ K.T / √d) @ V

/ √d 是为了让 score 的数值在不同 d(head 维度)下保持可控的方差——Q @ K.T 的方差大致正比于 d,开方后方差就稳了,softmax 不至于在大 d 时退化成 one-hot。

用 30 行 NumPy 跑一遍

完全不用框架,把 attention 算清楚:

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # 数值稳定
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Q, K, V shape: (T, d)"""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)              # (T, T)
    weights = softmax(scores, axis=-1)         # (T, T) 每行和为 1
    return weights @ V, weights                # (T, d), (T, T)

# 一个能跑的最小例子
np.random.seed(0)
T, d = 5, 8
Q = np.random.randn(T, d)
K = np.random.randn(T, d)
V = np.random.randn(T, d)

out, weights = attention(Q, K, V)
print("attention weights (每行和为 1):")
print(np.round(weights, 2))
print("\n输出 shape:", out.shape)

输出大概像:

attention weights:
[[0.25 0.21 0.18 0.20 0.16]
 [0.13 0.34 0.22 0.18 0.13]
 [0.19 0.16 0.41 0.16 0.08]
 ...]

每行都是这个位置对所有位置的关注度分布。第 3 行第 3 列权重最高(0.41),符合直觉——一个 token 通常对自己有较高关注。

加上 causal mask 让它能做语言建模

上面写的是双向 attention,BERT 用的就是这种。生成式语言模型(GPT 系列)需要 因果 mask——位置 i 只能看到位置 [0, i],不能看到未来。

def causal_attention(Q, K, V):
    d = Q.shape[-1]
    T = Q.shape[0]
    scores = Q @ K.T / np.sqrt(d)
    # 把未来的位置设成 -inf,softmax 后就是 0
    mask = np.triu(np.ones((T, T)), k=1).astype(bool)
    scores[mask] = -np.inf
    weights = softmax(scores)
    return weights @ V, weights

out, weights = causal_attention(Q, K, V)
print(np.round(weights, 2))
[[1.00 0.00 0.00 0.00 0.00]    ← 第 0 步只能看自己
 [0.28 0.72 0.00 0.00 0.00]    ← 第 1 步看前 2 个
 [0.31 0.26 0.43 0.00 0.00]
 [0.24 0.20 0.21 0.35 0.00]
 [0.19 0.17 0.41 0.13 0.10]]

上三角全是 0,下三角每行和为 1。这就是 GPT 的 attention 形状。

切到 PyTorch,加上可学习参数

NumPy 版讲清楚了机制。生产里 Q/K/V 不是直接给的,而是用 三个独立 Linear 层从同一个输入 x 投影出来:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SingleHeadAttention(nn.Module):
    def __init__(self, dim, causal=True):
        super().__init__()
        self.W_q = nn.Linear(dim, dim, bias=False)
        self.W_k = nn.Linear(dim, dim, bias=False)
        self.W_v = nn.Linear(dim, dim, bias=False)
        self.W_o = nn.Linear(dim, dim, bias=False)
        self.causal = causal

    def forward(self, x):
        # x: (B, T, C)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(Q.size(-1))

        if self.causal:
            T = x.size(1)
            mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            scores = scores.masked_fill(mask, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        out = weights @ V
        return self.W_o(out)

四个 Linear 都是可学习的。模型在训练中自己决定"Query 应该提取什么特征""Key 应该暴露什么特征""Value 应该传递什么内容"——这是 attention 的全部表达能力。

最后那个 W_o 是输出投影,让 attention 输出再过一次线性变换。从纯数学角度它可以省略,但实际所有 Transformer 实现都保留——它给了多头融合后做混合的容量(下一篇 multi-head 会展开)。

把它训进一个 toy 任务,看权重学到了什么

光看代码不够直观。把上面这个 SingleHeadAttention 嵌进一个最小语言模型,用一段重复模式数据训几百步,再把 attention 权重画出来:

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# 数据:一段有强模式的人造文本("a后面跟b"的规律)
text = "ababababababab abc abc abc xyzxyzxyz" * 100
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
data = torch.tensor([stoi[c] for c in text])
V_SIZE, DIM, CTX = len(chars), 32, 16

class ToyAttnLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok = nn.Embedding(V_SIZE, DIM)
        self.pos = nn.Embedding(CTX, DIM)
        self.attn = SingleHeadAttention(DIM)
        self.head = nn.Linear(DIM, V_SIZE)
    def forward(self, x, return_weights=False):
        T = x.size(1)
        h = self.tok(x) + self.pos(torch.arange(T))
        # 复用一下 attention 内部,多保存一份权重
        Q, K, V = self.attn.W_q(h), self.attn.W_k(h), self.attn.W_v(h)
        scores = (Q @ K.transpose(-2, -1)) / (DIM**0.5)
        mask = torch.triu(torch.ones(T, T), 1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
        weights = F.softmax(scores, -1)
        out = self.attn.W_o(weights @ V)
        logits = self.head(out)
        return (logits, weights) if return_weights else logits

# 训练
model = ToyAttnLM()
opt = torch.optim.Adam(model.parameters(), lr=3e-3)
def get_batch():
    idx = torch.randint(0, len(data) - CTX - 1, (32,))
    x = torch.stack([data[i:i+CTX] for i in idx])
    y = torch.stack([data[i+1:i+CTX+1] for i in idx])
    return x, y

for step in range(800):
    x, y = get_batch()
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, V_SIZE), y.view(-1))
    opt.zero_grad(); loss.backward(); opt.step()
    if step % 100 == 0:
        print(f"step {step}: loss={loss.item():.3f}")

# 取一段文本看权重
sample = "abababab"
x = torch.tensor([[stoi[c] for c in sample]])
with torch.no_grad():
    _, w = model(x, return_weights=True)

print("\nattention weights (行 i 看到了哪些列 j):")
print(torch.round(w[0] * 100).int())   # 百分制看着方便

典型输出会显示某些位置很清楚地学会了"看前一个 token"——因为 a 总是被 b 跟着、b 总是被 a 跟着,模型只要学会 attn[i, i-1] ≈ 1 就能很好预测。

这就是 attention 最直接的解释:它学的是一组动态的"我该注意哪里"的规则,规则的具体形式由数据决定。

拆开看看不同位置的输出会怎样

attention 之所以强大,不是因为它有什么神秘力量,而是因为它让模型有了灵活的信息传递机制。下面几个反直觉实验帮你建立这种感觉。

实验 1:把所有 Query 都换成同一个向量。

Q_all = Q[0:1].expand_as(Q)   # 每个位置都用第 0 个的 Q
scores = (Q_all @ K.transpose(-2, -1)) / math.sqrt(d)

结果:所有位置的 attention 权重变成同一行,每个位置输出都看同样的上下文混合——模型退化成"只有一个查询的检索器"。Query 提供差异化的信号

实验 2:把 V 全部清零。

out = weights @ torch.zeros_like(V)   # 全 0

不管 attention 权重多漂亮,输出全是 0。V 才是真正传递的内容,attention 决定的是怎么混合,不决定混合什么。

实验 3:把所有 K 设成一样。

K_const = K[0:1].expand_as(K)

Q @ K.T 每一列都一样,softmax 后所有位置的权重都是均匀分布——attention 退化成普通平均池化,跟开头我们觉得"不够用"的方案一样。K 决定每个位置如何"被找到"

跑一遍这三个实验,Q/K/V 各自的角色就不再抽象了。

一个常被忽略的细节:dropout 加在哪

工业实现的 attention 通常在两个地方加 dropout:

# 1) 加在 attention 权重上(attention dropout)
weights = F.dropout(weights, p=0.1, training=self.training)
out = weights @ V

# 2) 加在最后输出(residual dropout)
out = F.dropout(self.W_o(out), p=0.1, training=self.training)

第一个是为了让模型不过分依赖某几个固定位置;第二个是 Transformer block 标配。早期 GPT 都有,但新一代大模型(Llama、Qwen)多数把 attention dropout 设为 0——它们规模够大、数据够多,dropout 反而拖慢收敛。

我对 attention 的一个直觉

写完这些代码再看 softmax(QK/√d)V 这个公式,会发现它其实在做一件很简单的事:

让每个位置自己定义"我要什么"(Q),让每个位置告诉别人"我有什么"(K, V),然后把它们按"匹配度"组合起来。

公式短到一行,但表达力极强——任意位置可以瞬间联系任意其他位置,不受距离限制(这是 RNN 做不到的);权重完全由数据决定,没有任何预设规则(这是 CNN 的固定 kernel 做不到的)。

后面所有的扩展——multi-head(多组 Q/K/V 并行)、causal mask(限制只看过去)、KV cache(缓存历史 K/V 不重算)、GQA(让多个 Q 共享 K/V 省显存)、FlashAttention(不显式构造 T×T 矩阵)——都是在这个基础上做工程优化。核心机制就这一行公式

下一篇我们把单 head 扩展成 multi-head,看为什么"多个并行的小 attention"比"一个大的 attention"表达力更强。

延伸阅读

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:项目 04:手写 scaled dot-product attention

本文链接:https://www.sshipanoo.com/blog/ai/llm-roadmap/项目04-attention/

本文最后一次更新为 天前,文章中的某些内容可能已过时!