并行的不是一种关系,而是多组关系
在开始之前:把这一篇要用到的词先讲清楚
为了让读这篇时不用切窗口去查别的资料,把后面会反复出现的术语先放在这里:
- token:模型看到的最小单位。中文里一个 token 经常对应一个字或一两个字,英文里经常对应一个常用单词或常见前缀/后缀。怎么切是上一个项目
lab-01-tokenizer讲的 - embedding:每个 token 在模型内部对应的一个向量(比如长度 256 的一串实数)。模型"看到" token 实际上看到的是这串数字
- attention:让每个 token 在生成自己的表示时,能"看到"其他 token 并按相关度做加权聚合的机制。项目 04 详细写过
- Q / K / V(query / key / value):attention 的三个内部向量。直觉上 Q 表示"我在找什么",K 表示"我能被怎么找到",V 表示"如果你找到我,我给你什么内容"
- head:在 attention 内部并行的一组独立计算单元。一个 head 算一组 Q/K/V,得到一组注意力权重
- softmax:把一组任意实数变成一组和为 1 的非负数(概率分布)的函数:
softmax(x_i) = exp(x_i) / Σ exp(x_j)。值越大的输入对应的输出概率越高 - dim / hidden size:模型每一层每个 token 的向量长度。常见 256 / 512 / 768 / 4096 等
- perplexity(困惑度):衡量语言模型预测下一个 token 时"有多不确定"的指标,越小越好。它是 cross entropy loss 的指数:
perplexity = exp(loss) - LayerNorm / RMSNorm:在每一层之间把向量归一化(让数值大小可控),防止训练时数值爆炸或消失。RMSNorm 是 LayerNorm 的简化版,去掉了减均值步骤
- FLOP(floating-point operation):浮点运算次数,衡量计算量的常用单位。1 GFLOP = 10 亿次浮点运算
单头到底差在哪
项目 04 写完单头 attention 之后,可以做一个非常直接的对比实验:训练同一个 toy 语言模型(toy 这里指为了教学用的小规模模型,可能就几百万参数、几分钟能训完,不追求生产质量),把 dim 固定(比如 256),让单头 attention 跟 "4 个 head × 每 head 64 维" 的多头跑同样数据,比 perplexity。在任何稍微有点结构的语料上,多头会明显更低。
为什么?单头 attention 的权重矩阵 A = softmax(QK^T / √d) 是一个形状为 (T, T) 的矩阵(T 是序列长度,也就是有多少个 token)。每一行是一个 softmax 分布,给出"位置 i 应该多大权重看位置 j"的概率。关键限制:每对位置 (i, j) 只有一个标量权重。
但语言里一个 token 跟前文同时有多种关系:
- 句法依赖(syntactic dependency):句子里词与词之间的语法连接。例如 "the" 应该看后面的名词,因为定冠词修饰名词
- 共指(coreference):代词指向哪个名词。例如 "John lost his keys, he can't drive" 里的 "he" 共指 "John"
- 局部依赖:相邻几个 token 之间的紧密联系,例如固定搭配
- 长程语义:跨越很多 token 的主题呼应,例如段尾呼应段首
单头要把这几种关系压在同一个标量权重里——要么妥协(每种关系都学一点但都不深入),要么训练过程中轮流学一种,被新的覆盖掉。多头的解法很直接:给每对位置算多个独立的注意力分布,每个分布对应一种"视角"。h 个 head 就是 h 个独立 attention 在同一个 token 序列上并行跑,每个 head 有自己的 Q/K/V、自己的权重分布、自己的输出。
这不是"做 h 次单头然后平均"——每个 head 的 Q/K/V 是从输入投影出来的不同子空间(子空间 指原向量空间的一部分维度,比如 256 维空间切成 4 个 64 维子空间)。子空间不同,看到的东西就不同。
为什么要除 √d——一段非看不可的方差推导
绝大多数教程跳过 / √d 的推导,但理解它之后你会知道为什么 head_dim 不能开太大。
先说一个统计事实:如果一个随机变量 X 服从均值 0、方差 1 的分布(这叫标准正态分布或类似的分布),两个独立的 X 和 Y 相乘 X·Y,结果的方差也是 1。如果你把 d 个独立的乘积加起来:
S = X_1·Y_1 + X_2·Y_2 + ... + X_d·Y_d
Var(S) = d × Var(X·Y) = d
Std(S) = √d
这是问题的根源。
在 attention 里,Q 和 K 经过 LayerNorm(归一化层,把每个向量调成均值 0、方差 1)之后,它们的每个元素大致服从均值 0、方差 1 的分布。点积 q·k = Σ q_i × k_i 正好就是 d 项独立乘积之和,所以 Std(q·k) ≈ √d。
当 d 大(例如 d=128)时,√d ≈ 11——一些 score 会大到 ±11 量级。
把 ±11 量级的 score 送进 softmax 会发生什么?
import math
softmax_max = math.exp(11) / (math.exp(11) + math.exp(-11) + ...)
≈ 1.0 # 输入 11 的位置占绝对优势
其他位置 ≈ 0.0
整个分布退化成 one-hot(one-hot 向量 指只有一个位置是 1、其余全是 0 的向量)。一旦 attention 权重接近 one-hot,模型只在看一个位置,梯度(梯度指反向传播时用来更新参数的方向信号)只通过这一个位置回传,模型基本学不动。
解决方法是把 score 除以 √d,让它的标准差回到 1 量级:
score' = (q · k) / √d → Std(score') ≈ 1
softmax 在 ±1 量级的输入上是平滑的,梯度能通过所有位置正常回传。这就是 √d 这个看似魔法常数的真正来源——它把方差从 d 拉回 1。
这个推导对多头 attention 也成立——每个 head 用的是 head_dim 而不是总 dim:
scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim) # 不是 sqrt(dim)
如果错写成 / sqrt(dim),多头时 head_dim < dim,每个 head 的 score 反而被压得过小(标准差小于 1),softmax 分布太平,attention 不集中。这是 MHA 实现里最常见的 bug 之一。
reshape 让多个 head 并行起来
工程实现的核心 trick 是不用 for 循环,而是把 head 维 reshape 到 batch 之后并行。
先解释几个 PyTorch 操作:
tensor.view(...):在不复制数据的前提下,把张量看成另一种形状。比如(2, 6)可以 view 成(2, 3, 2)或(12,)tensor.transpose(a, b):交换两个维度。形状(B, T, H, D)transpose(1, 2) 之后变(B, H, T, D)tensor.contiguous():transpose 之后内存布局不再连续,必须调用 contiguous 把它复制成连续布局,否则后续 view 会报错
设 dim = 256、n_head = 4,那么每个 head 的 head_dim = 64。完整的形状变换流:
x: (B, T, 256)
W_q(x): (B, T, 256) ← 一次 Linear,包含所有 4 个 head 的 Q
.view(B, T, 4, 64): (B, T, 4, 64) ← 把 256 维切成 4 份 × 64 维
.transpose(1, 2): (B, 4, T, 64) ← 把 head 维转到第 1 位,变成 batch-like
scores = q @ k.T: (B, 4, T, T) ← 前两维 (B, H) 被矩阵乘当成批量并行
softmax(scores): (B, 4, T, T)
out = scores @ v: (B, 4, T, 64)
.transpose(1, 2): (B, T, 4, 64) ← 把 head 维转回来
.contiguous().view: (B, T, 256) ← 4 个 head 的输出拼接回 256 维
W_o(out): (B, T, 256) ← 输出投影
关键点:PyTorch 的矩阵乘 @ 把最后两个维度当矩阵,前面的维度都当 batch 并行。(B, 4, T, 64) @ (B, 4, 64, T) 直接给你 (B, 4, T, T),4 个 head 一次性算完。
完整最小实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, dim, n_head, causal=True,
attn_dropout=0.0, resid_dropout=0.0):
"""
dim: 总隐藏维度,例如 256
n_head: head 数量,dim 必须能被它整除
causal: True 表示 decoder 风格(每个位置只看自己 + 历史)
attn_dropout: 作用在 attention 权重上的 dropout(防过拟合)
resid_dropout: 作用在输出上的 dropout
"""
super().__init__()
assert dim % n_head == 0, f"dim={dim} 必须能被 n_head={n_head} 整除"
self.dim = dim
self.n_head = n_head
self.head_dim = dim // n_head
# 三个 Linear 一次性算出所有 head 的 Q/K/V,bias=False 是现代实现的默认
# 因为后面有 LayerNorm/RMSNorm,bias 会被吃掉所以省去
self.W_q = nn.Linear(dim, dim, bias=False)
self.W_k = nn.Linear(dim, dim, bias=False)
self.W_v = nn.Linear(dim, dim, bias=False)
# 输出投影:让不同 head 的输出有互相组合的机会
self.W_o = nn.Linear(dim, dim, bias=False)
self.attn_drop = nn.Dropout(attn_dropout)
self.resid_drop = nn.Dropout(resid_dropout)
self.causal = causal
def forward(self, x, return_attn=False):
# x: (Batch, Time/序列长度, Channel/向量维度)
B, T, C = x.shape
H, D = self.n_head, self.head_dim
# 1. 投影 + reshape + 转置:把 (B, T, C) 变成 (B, H, T, D)
# H 在第 1 维,后面矩阵乘会把 (B, H) 当批量
q = self.W_q(x).view(B, T, H, D).transpose(1, 2)
k = self.W_k(x).view(B, T, H, D).transpose(1, 2)
v = self.W_v(x).view(B, T, H, D).transpose(1, 2)
# 2. 算 attention scores:(B, H, T, D) @ (B, H, D, T) = (B, H, T, T)
# 注意是 sqrt(D),不是 sqrt(C)!见上一节方差推导
scores = (q @ k.transpose(-2, -1)) / math.sqrt(D)
# 3. 因果 mask:把未来位置的 score 设为 -inf
# softmax(-inf) = 0,未来位置就完全看不到
if self.causal:
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
# 4. softmax + dropout + 加权求和
attn = F.softmax(scores, dim=-1)
attn = self.attn_drop(attn)
out = attn @ v # (B, H, T, D)
# 5. 把 head 维转回去,拼回 (B, T, C)
out = out.transpose(1, 2).contiguous().view(B, T, C)
# 6. 输出投影 + dropout
out = self.resid_drop(self.W_o(out))
if return_attn:
return out, attn # attn 用于可视化
return out
实现里有 5 个新手常踩的坑:
W_q的输出维度是dim而不是head_dim——256 实际上是"4 个 head 各 64 维的 Q"拼起来的,view 一下就切开了。这避免了写 4 个独立 Linear,参数总数也不变transpose(1, 2)不可省——只有把 head 维提到 batch 维之后,矩阵乘才会按 head 独立算.contiguous()在 transpose 之后必须加一次,否则 view 会报view size is not compatible with input tensor's size and stride- scale 用
sqrt(head_dim)不是sqrt(dim)——见上一节方差推导 W_o不能省——见后面"head 之间怎么通信"那一节
参数和 FLOP 怎么算
很多人第一次看到 MHA 会担心"开 8 个 head 是不是参数 8 倍"。其实不是,参数量跟 head 数完全无关。
设 dim = 512,无论 n_head 是 1、4、8 还是 16,参数都是:
| 组件 | shape | 参数量 |
|---|---|---|
| W_q | (dim, dim) | dim² = 262144 |
| W_k | (dim, dim) | dim² = 262144 |
| W_v | (dim, dim) | dim² = 262144 |
| W_o | (dim, dim) | dim² = 262144 |
| 总计 | 4 × dim² ≈ 1M |
多头只是把同一组参数"切分"成多个并行子空间使用——总维度没变,只是从 1 个 512 维空间变成 8 个 64 维空间。
FLOP(前面解释过:浮点运算次数,衡量计算量)也可以精确算。对一次 forward(batch=1,seq=T):
| 步骤 | FLOP |
|---|---|
三个 Q/K/V 投影 Linear(dim, dim) 各做 T 次 | 3 × T × dim² |
Q @ K^T:把 (T, dim) 和 (dim, T) 相乘 | T × T × dim = T² × dim |
attn @ V:把 (T, T) 和 (T, dim) 相乘 | T² × dim |
输出投影 W_o(out) | T × dim² |
| 总计 | 4 × T × dim² + 2 × T² × dim |
观察两个结论:
- T 短时(T < dim),FLOP 主导是
4 × T × dim²,跟 T 是线性关系 - T 长时(T > dim),FLOP 主导是
2 × T² × dim,跟 T 是平方关系——这就是 attention 长上下文 O(T²) 复杂度的来源,也是 FlashAttention(项目 15)要解决的问题 - n_head 不出现在公式里——MHA 跟单头总 FLOP 完全一致
每个 head 的 head_dim 变小是有代价的(256 → 32)。head_dim 太小时单 head 的表达力会下降。所以 head 数不是越多越好,要跟 dim 一起平衡。看几个真实模型的配置:
| 模型 | dim | n_head | head_dim |
|---|---|---|---|
| GPT-2 small | 768 | 12 | 64 |
| LLaMA 7B | 4096 | 32 | 128 |
| LLaMA 70B | 8192 | 64 | 128 |
| GPT-3 175B | 12288 | 96 | 128 |
head_dim = 64 或 128 是经验上的稳定值——再小(< 32)单 head 容量不够;再大(> 256)多头收益变薄。
不同 head 真的学到不同东西吗——画出来看
光说不够,把 attention 权重画出来。seaborn 的 heatmap 函数可以把一个二维矩阵渲染成颜色块图——颜色深浅对应数值大小:
import matplotlib.pyplot as plt
import seaborn as sns
# 假设已经训好一个有 MHA 的 toy LM
model.eval()
sample = "the quick brown fox jumps over the lazy dog"
ids = torch.tensor([[stoi[c] for c in sample]]) # stoi: 字符到 ID 的字典
with torch.no_grad():
_, attn = model.mha(model.encode(ids), return_attn=True)
# attn shape: (1, n_head, T, T),每个 (T, T) 是一个 head 的注意力矩阵
n_head = attn.shape[1]
fig, axes = plt.subplots(1, n_head, figsize=(3*n_head, 3))
for h in range(n_head):
sns.heatmap(
attn[0, h].cpu().numpy(),
ax=axes[h],
cmap='Blues', # 颜色映射:越蓝表示数值越大
square=True,
cbar=False,
xticklabels=list(sample),
yticklabels=list(sample),
)
axes[h].set_title(f"Head {h}")
plt.tight_layout()
plt.savefig("mha_heads.png", dpi=120)
跑出来你会看到几种典型模式(横轴是"被看的位置 j",纵轴是"看的位置 i"):
- 对角线 head:颜色集中在
(i, i-1)或(i, i-2)附近,专门关注最近 1~2 个 token——做的是 n-gram 风格的局部依赖。n-gram 指连续 n 个 token 的组合,比如 bigram 是相邻两个 token - bos 主导 head(bos = beginning of sequence,序列开头那个特殊 token):所有位置都盯着开头那个 token——这是 attention sink 现象,大部分训练好的模型里都至少有 1~2 个 head 是这样。"sink" 字面是水槽,意思是"接收 attention 多余权重的吸收器"——具体原因学界还在研究,主流解释是模型需要一个"垃圾桶"放掉那些不知道放哪的注意力权重
- 散点 head:每个位置都集中看几个特定远处 token,可能在学共指、句法、长程依赖
- uniform head:权重接近均匀分布(每个位置都差不多)——这个 head 没在工作,是冗余的
最后一种现象不是个例。Michel 等人 2019 年的论文 Are Sixteen Heads Really Better than One? 系统验证过:训练好的 Transformer 里有 20%~40% 的 head 是冗余的——剪掉它们对下游性能影响 < 1%。后续 Voita 等人 2019 Analyzing Multi-Head Self-Attention 把 head 按功能分了 4 类(positional:关注固定位置;syntactic:关注语法相关的位置;rare-words:关注罕见词;其他:功能不明显),证明确实有一部分 head 可以被精确解释,另一部分主要是噪声。
head 之间怎么通信——拆 W_o 看清楚
多头的一个关键设计是 head 之间在 attention 内部完全独立——head 1 的输出不会进入 head 2 的计算。它们的唯一会面在最后的 W_o。
可以用线性代数视角看 W_o 的作用。设拼接后的向量 concat_out 是把 h 个 (T, head_dim) 拼起来的 (T, dim):
W_o.shape = (dim, dim)
W_o @ concat_out 写成分块矩阵乘:
= (W_o,1 | W_o,2 | ... | W_o,h) @ [out_1; out_2; ...; out_h]
= Σ_j W_o,j @ out_j
其中 W_o,j 是 W_o 切成 h 份后第 j 块,shape 是 (dim, head_dim)。也就是说,W_o 的每一行都是 h 个 head 输出向量的线性组合的系数。这正是不同 head 互相通信的唯一机制——如果没有 W_o,后续层只能看到 h 个互不相关的 head 输出拼起来的"长向量"。
实验 1:去掉 W_o
# forward 里把 self.W_o(out) 改成 out
return out
观察:验证集 loss 会显著上升(典型 +0.2~0.5)。多头退化成"h 个互不通气的小模型并联"。
实验 2:把所有 head 强制学一样
# 通过初始化和梯度 hook 让 W_q / W_k / W_v 的 h 个 head 切片永远相等
# 等价于"h 头但所有头算一样的东西"
观察:跟单头复制 h 份等价,loss 跟单头一致。证明多头的实际价值确实来自 head 之间的差异化。
实验 3:n_head 扫描,dim 固定 256
| n_head | head_dim | val loss |
|---|---|---|
| 1 | 256 | 3.21 |
| 2 | 128 | 3.05 |
| 4 | 64 | 2.91 |
| 8 | 32 | 2.94(接近最优) |
| 16 | 16 | 3.08(开始反弹) |
| 32 | 8 | 3.30(严重退化) |
存在最优 n_head,超过之后 head_dim 太小不够用反而变差。这印证了前面 head_dim 经验值 64~128 的结论。
Self-attention vs cross-attention
MHA 不只用在 self-attention 里。先解释这两个词:
- self-attention(自注意力):Q/K/V 都来自同一个输入序列。一个 token 在看同一句话里的其他 token
- cross-attention(交叉注意力):Q 来自一个序列,K/V 来自另一个序列。一个序列的 token 在看另一个序列的 token
原始 Transformer 论文用在机器翻译上,decoder 里有两个 attention:
# decoder 内部
x = self_attention(x) # Q/K/V 都来自 decoder 自己(看已生成的历史输出)
x = cross_attention(x, enc) # Q 来自 decoder,K/V 来自 encoder(看待翻译的原文)
代码层面差异极小——只是 K/V 的来源换成另一个张量:
def cross_attention_forward(self, x, context):
"""x: 来自 decoder, context: 来自 encoder"""
q = self.W_q(x) # Q 用 decoder 状态
k = self.W_k(context) # K/V 用 encoder 状态
v = self.W_v(context)
# 后面的 reshape + attention 一模一样
...
重要事实:现代主流大模型(GPT 系列、LLaMA、Qwen、DeepSeek)都是 decoder-only 架构——没有 encoder,也就没有 cross-attention,只有 causal self-attention。encoder-decoder 架构(T5、BART、mT5、原始 Transformer 翻译用法)才有 cross-attention。这影响到你后续读论文时怎么理解模型结构。
GQA / MQA——MHA 在推理时的最大问题
现在主流大模型(LLaMA 3、Qwen 2、Mistral 等)几乎都不用纯 MHA 了,而是用 GQA(Grouped Query Attention):Query head 数多,K/V head 数少,多个 Q 共享一组 K/V。
为什么要做这个改动?要先理解 KV cache:
KV cache——大模型生成文本是逐 token 进行的(先输出第 1 个 token,再输出第 2 个......)。生成第 N 个 token 时,attention 需要历史所有 token 的 K 和 V。如果每生成一个新 token 都重新算前面所有 K/V,O(T²) 的计算量会爆炸。所以工程上把已经算过的 K/V 缓存起来,每生成一个新 token 只算它自己的新 K/V,跟缓存拼一下用。这块缓存就是 KV cache。
KV cache 的显存占用怎么算?设:
L 模型层数
H_kv 每层的 K/V head 数
D 每个 head 的维度
T 上下文长度(已生成 + 待生成的 token 数)
batch 并发请求数(同时服务多少用户)
bytes 每个数值的字节数(fp16 = 2 字节)
总显存:
mem = 2 × L × H_kv × D × T × batch × bytes
(× 2 是因为 K 和 V 各存一份)
以 LLaMA 2 70B 为例(L=80, D=128),对比两种配置:
MHA 版本(H_kv = 64,全部 head 都各自的 K/V):
2 × 80 × 64 × 128 × 4096 × 1 × 2 ≈ 10.7 GB ← 单条请求 4K 上下文
GQA 版本(H_kv = 8,多个 Q 共享 K/V):
2 × 80 × 8 × 128 × 4096 × 1 × 2 ≈ 1.3 GB
单条请求就省 8 倍显存。如果同时服务 32 个用户(batch = 32):
| 配置 | KV cache 总显存 | 一张 H100(80 GB)能装下吗 |
|---|---|---|
| MHA | 343 GB | 装不下,要 5 张卡才够 |
| GQA | 43 GB | 单卡有余 |
这就是为什么所有大模型推理向 GQA 收敛。MQA(Multi-Query Attention)是更激进的版本——所有 Q head 共享 1 组 K/V,cache 最小但质量损失开始明显(一般掉 1~2% 任务准确率)。
GQA 的细节是项目 12 的主题(具体 H_q=32, H_kv=8 的实现 + 质量对比),这里只是让你知道:MHA 是教科书定义,但生产模型基本都换成了 GQA 的变体。
一个常被忽略的细节:pre-LN vs post-LN
把 MHA 放进 Transformer block 时,LayerNorm 的放法有两种:
# Post-LN(原始 Transformer 2017 论文)
x = LayerNorm(x + MHA(x))
x = LayerNorm(x + FFN(x))
# Pre-LN(GPT-2 / LLaMA / Qwen 都用)
x = x + MHA(LayerNorm(x))
x = x + FFN(LayerNorm(x))
先解释这两种结构里的 "残差连接"(residual connection):x + f(x) 这种把输入和输出加起来的写法。它最早来自 ResNet(2015),让深层网络更易训——梯度可以从顶层"抄近道"直接流到底层。
Post-LN 和 Pre-LN 的区别是 LayerNorm 放在残差之后还是之前。差异看起来微小,但实测差距很大:
- Post-LN 训练不稳定:早期 step 容易出现 NaN(无穷大或者无意义的数),需要 learning rate warmup 才能起步。warmup 指训练开始时先用很小的学习率,慢慢升到正常值,避免大梯度把模型一下子搞坏
- Pre-LN 训练稳定:可以直接用大学习率,warmup 也可以省
为什么?Pre-LN 的残差是 x + 模块(LN(x)),残差路径上没有 LN,梯度从顶层直接流到底层不衰减。Post-LN 的残差被 LN "搅"了一次,深层时梯度逐层衰减。
所有现代 LLM 都用 Pre-LN。你写 MHA 时如果发现训练不稳定,先检查 LN 位置。
更进一步,LLaMA 把 LayerNorm 换成了 RMSNorm——x / RMS(x) * γ,相比 LayerNorm 去掉了减均值步骤,训练略快且效果几乎一致。这是项目 06(decoder block)的主题。
一个反直觉的事
写完 MHA 之后,可以反向问自己:为什么不用一个 256 维的大 head,而非要切成 4 个 64 维?
数学上单头的表达能力严格大于多头——多头是单头的一个约束子集("约束"指多头强制不同子空间不共享投影矩阵,而单头允许任意混合)。但实测多头一直赢,这有点反直觉。
主流解释有三种,每种都没有被严格证明:
- 优化更容易:多个独立子空间各自下降,比一个高维空间整体下降更容易找到好解。Adam 等优化器(优化器 指根据梯度更新参数的算法,Adam 是最常用的一种)在不同子空间能独立调节学习率
- 隐式正则化:强制 head 之间不通信,相当于约束模型不能依赖跨子空间的复杂耦合,泛化反而更好。泛化指模型在没见过的数据上的表现
- 硬件友好:多个小矩阵乘比一个大矩阵乘对 GPU 缓存和 tensor core(tensor core 是 NVIDIA GPU 里专门做矩阵乘的硬件单元)都更友好
工程上"多头比单头好"这件事经过了 8 年(2017 至今)所有模型反复印证,已经是事实,但底层原因仍是研究问题。
把这一节内容串起来
写到这里你应该能在脑子里回答这些问题:
- 单头为什么不够:每对位置只能给一个标量权重,压不下多种语言关系
- 为什么除 √d:让 softmax 输入的方差稳定在 1 量级,避免梯度死掉
- 多头怎么并行:reshape + transpose 把 head 维提到 batch 之后用矩阵乘批量算
- 参数和 FLOP 跟 head 数无关:head 切分只是分配方式,总维度没变
- W_o 不能省:它是不同 head 唯一互相通信的地方
- head_dim 经验值 64~128:再小单 head 容量不够,再大多头收益变薄
- 生产模型用 GQA 而非 MHA:KV cache 显存差 8 倍以上,并发场景下决定能不能跑
下一篇项目 06 把 MHA 和 FFN 拼成完整的 Transformer decoder block,加上 residual、LayerNorm、合理的 dropout 位置,得到一个能堆深的最小单元。
延伸阅读
- Attention Is All You Need 第 3.2 节——MHA 的原始定义
- Are Sixteen Heads Really Better than One?——head 剪枝实验
- Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting——head 功能分类
- GQA: Training Generalized Multi-Query Transformer Models——LLaMA 2 用的 GQA 论文
- On Layer Normalization in the Transformer Architecture——pre-LN vs post-LN 训练稳定性分析
- The Illustrated Transformer (MHA 部分)——图解版
版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。
(采用 CC BY-NC-SA 4.0 许可协议进行授权)
本文标题:项目 05:从单头到多头注意力
本文链接:https://www.sshipanoo.com/blog/ai/llm-roadmap/lab-05-mha/
本文最后一次更新为 天前,文章中的某些内容可能已过时!