公式很短,动机却很值得想清楚
先问一个具体的问题
给你一句话切成 5 个 token:[I, like, the, blue, sky]。现在你想让模型处理第 4 个 token blue,让它的内部表示带上整句话的上下文信息。
最直白的做法是把所有 5 个 token 的向量平均一下:
context = (h1 + h2 + h3 + h4 + h5) / 5
这能不能用?能用。但它有一个明显问题:对 blue 来说,sky 应该比 I 更重要——它们在语义上更紧密。把它们等权平均,等于丢失了"相关度"这个信号。
那把权重做成可学习的呢?
# 给每个位置训练一个固定权重
w = nn.Parameter(torch.randn(5))
context = (w[0]*h1 + w[1]*h2 + ... ) / w.sum()
也不行。这套权重一旦训练完就固定了。换一个句子,比如 [I, like, the, red, sky],模型用的还是同一套权重——但显然 red 和 blue 在不同语境里"应该看哪个 token"是不同的。
我们真正需要的是:权重要根据当前 token 是什么、其他 token 是什么,动态决定。
这就是 attention 的核心动机。
把动机翻译成公式
让每个 token 自己说三件事:
- "我想找什么样的信息"(Query)
- "我能提供什么样的信息"(Key)
- "如果你来找我,我会给你这些内容"(Value)
对位置 i 的 token:
- 用它的 Query 和所有其他 token 的 Key 算"匹配分数"——
score[i, j] = Q_i · K_j - 把这些分数过 softmax,得到一组和为 1 的权重
α[i, :] - 用这组权重对所有 Value 加权求和——
out_i = Σ α[i, j] * V_j
写成矩阵:
Attention(Q, K, V) = softmax(Q @ K.T / √d) @ V
/ √d 是为了让 score 的数值在不同 d(head 维度)下保持可控的方差——Q @ K.T 的方差大致正比于 d,开方后方差就稳了,softmax 不至于在大 d 时退化成 one-hot。
用 30 行 NumPy 跑一遍
完全不用框架,把 attention 算清楚:
import numpy as np
def softmax(x, axis=-1):
x = x - x.max(axis=axis, keepdims=True) # 数值稳定
e = np.exp(x)
return e / e.sum(axis=axis, keepdims=True)
def attention(Q, K, V):
"""Q, K, V shape: (T, d)"""
d = Q.shape[-1]
scores = Q @ K.T / np.sqrt(d) # (T, T)
weights = softmax(scores, axis=-1) # (T, T) 每行和为 1
return weights @ V, weights # (T, d), (T, T)
# 一个能跑的最小例子
np.random.seed(0)
T, d = 5, 8
Q = np.random.randn(T, d)
K = np.random.randn(T, d)
V = np.random.randn(T, d)
out, weights = attention(Q, K, V)
print("attention weights (每行和为 1):")
print(np.round(weights, 2))
print("\n输出 shape:", out.shape)
输出大概像:
attention weights:
[[0.25 0.21 0.18 0.20 0.16]
[0.13 0.34 0.22 0.18 0.13]
[0.19 0.16 0.41 0.16 0.08]
...]
每行都是这个位置对所有位置的关注度分布。第 3 行第 3 列权重最高(0.41),符合直觉——一个 token 通常对自己有较高关注。
加上 causal mask 让它能做语言建模
上面写的是双向 attention,BERT 用的就是这种。生成式语言模型(GPT 系列)需要 因果 mask——位置 i 只能看到位置 [0, i],不能看到未来。
def causal_attention(Q, K, V):
d = Q.shape[-1]
T = Q.shape[0]
scores = Q @ K.T / np.sqrt(d)
# 把未来的位置设成 -inf,softmax 后就是 0
mask = np.triu(np.ones((T, T)), k=1).astype(bool)
scores[mask] = -np.inf
weights = softmax(scores)
return weights @ V, weights
out, weights = causal_attention(Q, K, V)
print(np.round(weights, 2))
[[1.00 0.00 0.00 0.00 0.00] ← 第 0 步只能看自己
[0.28 0.72 0.00 0.00 0.00] ← 第 1 步看前 2 个
[0.31 0.26 0.43 0.00 0.00]
[0.24 0.20 0.21 0.35 0.00]
[0.19 0.17 0.41 0.13 0.10]]
上三角全是 0,下三角每行和为 1。这就是 GPT 的 attention 形状。
切到 PyTorch,加上可学习参数
NumPy 版讲清楚了机制。生产里 Q/K/V 不是直接给的,而是用 三个独立 Linear 层从同一个输入 x 投影出来:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SingleHeadAttention(nn.Module):
def __init__(self, dim, causal=True):
super().__init__()
self.W_q = nn.Linear(dim, dim, bias=False)
self.W_k = nn.Linear(dim, dim, bias=False)
self.W_v = nn.Linear(dim, dim, bias=False)
self.W_o = nn.Linear(dim, dim, bias=False)
self.causal = causal
def forward(self, x):
# x: (B, T, C)
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
if self.causal:
T = x.size(1)
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
weights = F.softmax(scores, dim=-1)
out = weights @ V
return self.W_o(out)
四个 Linear 都是可学习的。模型在训练中自己决定"Query 应该提取什么特征""Key 应该暴露什么特征""Value 应该传递什么内容"——这是 attention 的全部表达能力。
最后那个 W_o 是输出投影,让 attention 输出再过一次线性变换。从纯数学角度它可以省略,但实际所有 Transformer 实现都保留——它给了多头融合后做混合的容量(下一篇 multi-head 会展开)。
把它训进一个 toy 任务,看权重学到了什么
光看代码不够直观。把上面这个 SingleHeadAttention 嵌进一个最小语言模型,用一段重复模式数据训几百步,再把 attention 权重画出来:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# 数据:一段有强模式的人造文本("a后面跟b"的规律)
text = "ababababababab abc abc abc xyzxyzxyz" * 100
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
data = torch.tensor([stoi[c] for c in text])
V_SIZE, DIM, CTX = len(chars), 32, 16
class ToyAttnLM(nn.Module):
def __init__(self):
super().__init__()
self.tok = nn.Embedding(V_SIZE, DIM)
self.pos = nn.Embedding(CTX, DIM)
self.attn = SingleHeadAttention(DIM)
self.head = nn.Linear(DIM, V_SIZE)
def forward(self, x, return_weights=False):
T = x.size(1)
h = self.tok(x) + self.pos(torch.arange(T))
# 复用一下 attention 内部,多保存一份权重
Q, K, V = self.attn.W_q(h), self.attn.W_k(h), self.attn.W_v(h)
scores = (Q @ K.transpose(-2, -1)) / (DIM**0.5)
mask = torch.triu(torch.ones(T, T), 1).bool()
scores = scores.masked_fill(mask, float('-inf'))
weights = F.softmax(scores, -1)
out = self.attn.W_o(weights @ V)
logits = self.head(out)
return (logits, weights) if return_weights else logits
# 训练
model = ToyAttnLM()
opt = torch.optim.Adam(model.parameters(), lr=3e-3)
def get_batch():
idx = torch.randint(0, len(data) - CTX - 1, (32,))
x = torch.stack([data[i:i+CTX] for i in idx])
y = torch.stack([data[i+1:i+CTX+1] for i in idx])
return x, y
for step in range(800):
x, y = get_batch()
logits = model(x)
loss = F.cross_entropy(logits.view(-1, V_SIZE), y.view(-1))
opt.zero_grad(); loss.backward(); opt.step()
if step % 100 == 0:
print(f"step {step}: loss={loss.item():.3f}")
# 取一段文本看权重
sample = "abababab"
x = torch.tensor([[stoi[c] for c in sample]])
with torch.no_grad():
_, w = model(x, return_weights=True)
print("\nattention weights (行 i 看到了哪些列 j):")
print(torch.round(w[0] * 100).int()) # 百分制看着方便
典型输出会显示某些位置很清楚地学会了"看前一个 token"——因为 a 总是被 b 跟着、b 总是被 a 跟着,模型只要学会 attn[i, i-1] ≈ 1 就能很好预测。
这就是 attention 最直接的解释:它学的是一组动态的"我该注意哪里"的规则,规则的具体形式由数据决定。
拆开看看不同位置的输出会怎样
attention 之所以强大,不是因为它有什么神秘力量,而是因为它让模型有了灵活的信息传递机制。下面几个反直觉实验帮你建立这种感觉。
实验 1:把所有 Query 都换成同一个向量。
Q_all = Q[0:1].expand_as(Q) # 每个位置都用第 0 个的 Q
scores = (Q_all @ K.transpose(-2, -1)) / math.sqrt(d)
结果:所有位置的 attention 权重变成同一行,每个位置输出都看同样的上下文混合——模型退化成"只有一个查询的检索器"。Query 提供差异化的信号。
实验 2:把 V 全部清零。
out = weights @ torch.zeros_like(V) # 全 0
不管 attention 权重多漂亮,输出全是 0。V 才是真正传递的内容,attention 决定的是怎么混合,不决定混合什么。
实验 3:把所有 K 设成一样。
K_const = K[0:1].expand_as(K)
Q @ K.T 每一列都一样,softmax 后所有位置的权重都是均匀分布——attention 退化成普通平均池化,跟开头我们觉得"不够用"的方案一样。K 决定每个位置如何"被找到"。
跑一遍这三个实验,Q/K/V 各自的角色就不再抽象了。
一个常被忽略的细节:dropout 加在哪
工业实现的 attention 通常在两个地方加 dropout:
# 1) 加在 attention 权重上(attention dropout)
weights = F.dropout(weights, p=0.1, training=self.training)
out = weights @ V
# 2) 加在最后输出(residual dropout)
out = F.dropout(self.W_o(out), p=0.1, training=self.training)
第一个是为了让模型不过分依赖某几个固定位置;第二个是 Transformer block 标配。早期 GPT 都有,但新一代大模型(Llama、Qwen)多数把 attention dropout 设为 0——它们规模够大、数据够多,dropout 反而拖慢收敛。
我对 attention 的一个直觉
写完这些代码再看 softmax(QK/√d)V 这个公式,会发现它其实在做一件很简单的事:
让每个位置自己定义"我要什么"(Q),让每个位置告诉别人"我有什么"(K, V),然后把它们按"匹配度"组合起来。
公式短到一行,但表达力极强——任意位置可以瞬间联系任意其他位置,不受距离限制(这是 RNN 做不到的);权重完全由数据决定,没有任何预设规则(这是 CNN 的固定 kernel 做不到的)。
后面所有的扩展——multi-head(多组 Q/K/V 并行)、causal mask(限制只看过去)、KV cache(缓存历史 K/V 不重算)、GQA(让多个 Q 共享 K/V 省显存)、FlashAttention(不显式构造 T×T 矩阵)——都是在这个基础上做工程优化。核心机制就这一行公式。
下一篇我们把单 head 扩展成 multi-head,看为什么"多个并行的小 attention"比"一个大的 attention"表达力更强。
延伸阅读
- Attention Is All You Need——原始论文,第 3 节就是公式
- The Illustrated Transformer——经典图解,attention 部分图非常清晰
- Karpathy: Let's build GPT from scratch(视频)——一段视频里手把手从 bigram 推到 attention
- Attention? Attention!(Lilian Weng)——从 RNN attention 一路梳理到 self-attention 的历史脉络
版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。
(采用 CC BY-NC-SA 4.0 许可协议进行授权)
本文标题:项目 04:手写 scaled dot-product attention
本文链接:https://www.sshipanoo.com/blog/ai/llm-roadmap/项目04-attention/
本文最后一次更新为 天前,文章中的某些内容可能已过时!