并行的不是一种关系,而是多组关系

在开始之前:把这一篇要用到的词先讲清楚

为了让读这篇时不用切窗口去查别的资料,把后面会反复出现的术语先放在这里:

  • token:模型看到的最小单位。中文里一个 token 经常对应一个字或一两个字,英文里经常对应一个常用单词或常见前缀/后缀。怎么切是上一个项目 lab-01-tokenizer 讲的
  • embedding:每个 token 在模型内部对应的一个向量(比如长度 256 的一串实数)。模型"看到" token 实际上看到的是这串数字
  • attention:让每个 token 在生成自己的表示时,能"看到"其他 token 并按相关度做加权聚合的机制。项目 04 详细写过
  • Q / K / V(query / key / value):attention 的三个内部向量。直觉上 Q 表示"我在找什么",K 表示"我能被怎么找到",V 表示"如果你找到我,我给你什么内容"
  • head:在 attention 内部并行的一组独立计算单元。一个 head 算一组 Q/K/V,得到一组注意力权重
  • softmax:把一组任意实数变成一组和为 1 的非负数(概率分布)的函数:softmax(x_i) = exp(x_i) / Σ exp(x_j)。值越大的输入对应的输出概率越高
  • dim / hidden size:模型每一层每个 token 的向量长度。常见 256 / 512 / 768 / 4096 等
  • perplexity(困惑度):衡量语言模型预测下一个 token 时"有多不确定"的指标,越小越好。它是 cross entropy loss 的指数:perplexity = exp(loss)
  • LayerNorm / RMSNorm:在每一层之间把向量归一化(让数值大小可控),防止训练时数值爆炸或消失。RMSNorm 是 LayerNorm 的简化版,去掉了减均值步骤
  • FLOP(floating-point operation):浮点运算次数,衡量计算量的常用单位。1 GFLOP = 10 亿次浮点运算

单头到底差在哪

项目 04 写完单头 attention 之后,可以做一个非常直接的对比实验:训练同一个 toy 语言模型(toy 这里指为了教学用的小规模模型,可能就几百万参数、几分钟能训完,不追求生产质量),把 dim 固定(比如 256),让单头 attention 跟 "4 个 head × 每 head 64 维" 的多头跑同样数据,比 perplexity。在任何稍微有点结构的语料上,多头会明显更低。

为什么?单头 attention 的权重矩阵 A = softmax(QK^T / √d) 是一个形状为 (T, T) 的矩阵(T 是序列长度,也就是有多少个 token)。每一行是一个 softmax 分布,给出"位置 i 应该多大权重看位置 j"的概率。关键限制:每对位置 (i, j) 只有一个标量权重

但语言里一个 token 跟前文同时有多种关系:

  • 句法依赖(syntactic dependency):句子里词与词之间的语法连接。例如 "the" 应该看后面的名词,因为定冠词修饰名词
  • 共指(coreference):代词指向哪个名词。例如 "John lost his keys, he can't drive" 里的 "he" 共指 "John"
  • 局部依赖:相邻几个 token 之间的紧密联系,例如固定搭配
  • 长程语义:跨越很多 token 的主题呼应,例如段尾呼应段首

单头要把这几种关系压在同一个标量权重里——要么妥协(每种关系都学一点但都不深入),要么训练过程中轮流学一种,被新的覆盖掉。多头的解法很直接:给每对位置算多个独立的注意力分布,每个分布对应一种"视角"。h 个 head 就是 h 个独立 attention 在同一个 token 序列上并行跑,每个 head 有自己的 Q/K/V、自己的权重分布、自己的输出。

这不是"做 h 次单头然后平均"——每个 head 的 Q/K/V 是从输入投影出来的不同子空间(子空间 指原向量空间的一部分维度,比如 256 维空间切成 4 个 64 维子空间)。子空间不同,看到的东西就不同。

为什么要除 √d——一段非看不可的方差推导

绝大多数教程跳过 / √d 的推导,但理解它之后你会知道为什么 head_dim 不能开太大。

先说一个统计事实:如果一个随机变量 X 服从均值 0、方差 1 的分布(这叫标准正态分布或类似的分布),两个独立的 X 和 Y 相乘 X·Y,结果的方差也是 1。如果你把 d 个独立的乘积加起来:

S = X_1·Y_1 + X_2·Y_2 + ... + X_d·Y_d
Var(S) = d × Var(X·Y) = d
Std(S) = √d

这是问题的根源

在 attention 里,Q 和 K 经过 LayerNorm(归一化层,把每个向量调成均值 0、方差 1)之后,它们的每个元素大致服从均值 0、方差 1 的分布。点积 q·k = Σ q_i × k_i 正好就是 d 项独立乘积之和,所以 Std(q·k) ≈ √d

当 d 大(例如 d=128)时,√d ≈ 11——一些 score 会大到 ±11 量级。

把 ±11 量级的 score 送进 softmax 会发生什么?

import math
softmax_max = math.exp(11) / (math.exp(11) + math.exp(-11) + ...) 
            ≈ 1.0   # 输入 11 的位置占绝对优势
其他位置                            ≈ 0.0

整个分布退化成 one-hotone-hot 向量 指只有一个位置是 1、其余全是 0 的向量)。一旦 attention 权重接近 one-hot,模型只在看一个位置,梯度(梯度指反向传播时用来更新参数的方向信号)只通过这一个位置回传,模型基本学不动。

解决方法是把 score 除以 √d,让它的标准差回到 1 量级:

score' = (q · k) / √d    →    Std(score') ≈ 1

softmax 在 ±1 量级的输入上是平滑的,梯度能通过所有位置正常回传。这就是 √d 这个看似魔法常数的真正来源——它把方差从 d 拉回 1。

这个推导对多头 attention 也成立——每个 head 用的是 head_dim 而不是总 dim:

scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)   # 不是 sqrt(dim)

如果错写成 / sqrt(dim),多头时 head_dim < dim,每个 head 的 score 反而被压得过小(标准差小于 1),softmax 分布太平,attention 不集中。这是 MHA 实现里最常见的 bug 之一

reshape 让多个 head 并行起来

工程实现的核心 trick 是不用 for 循环,而是把 head 维 reshape 到 batch 之后并行。

先解释几个 PyTorch 操作:

  • tensor.view(...):在不复制数据的前提下,把张量看成另一种形状。比如 (2, 6) 可以 view 成 (2, 3, 2)(12,)
  • tensor.transpose(a, b):交换两个维度。形状 (B, T, H, D) transpose(1, 2) 之后变 (B, H, T, D)
  • tensor.contiguous():transpose 之后内存布局不再连续,必须调用 contiguous 把它复制成连续布局,否则后续 view 会报错

dim = 256n_head = 4,那么每个 head 的 head_dim = 64。完整的形状变换流:

x:                       (B,    T, 256)
W_q(x):                  (B,    T, 256)   ← 一次 Linear,包含所有 4 个 head 的 Q
.view(B, T, 4, 64):      (B,    T, 4, 64) ← 把 256 维切成 4 份 × 64 维
.transpose(1, 2):        (B, 4, T, 64)    ← 把 head 维转到第 1 位,变成 batch-like

scores = q @ k.T:        (B, 4, T, T)     ← 前两维 (B, H) 被矩阵乘当成批量并行
softmax(scores):         (B, 4, T, T)
out = scores @ v:        (B, 4, T, 64)

.transpose(1, 2):        (B, T, 4, 64)    ← 把 head 维转回来
.contiguous().view:      (B, T, 256)      ← 4 个 head 的输出拼接回 256 维
W_o(out):                (B, T, 256)      ← 输出投影

关键点:PyTorch 的矩阵乘 @最后两个维度当矩阵前面的维度都当 batch 并行(B, 4, T, 64) @ (B, 4, 64, T) 直接给你 (B, 4, T, T),4 个 head 一次性算完。

完整最小实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, n_head, causal=True,
                 attn_dropout=0.0, resid_dropout=0.0):
        """
        dim:        总隐藏维度,例如 256
        n_head:     head 数量,dim 必须能被它整除
        causal:     True 表示 decoder 风格(每个位置只看自己 + 历史)
        attn_dropout:  作用在 attention 权重上的 dropout(防过拟合)
        resid_dropout: 作用在输出上的 dropout
        """
        super().__init__()
        assert dim % n_head == 0, f"dim={dim} 必须能被 n_head={n_head} 整除"
        self.dim = dim
        self.n_head = n_head
        self.head_dim = dim // n_head

        # 三个 Linear 一次性算出所有 head 的 Q/K/V,bias=False 是现代实现的默认
        # 因为后面有 LayerNorm/RMSNorm,bias 会被吃掉所以省去
        self.W_q = nn.Linear(dim, dim, bias=False)
        self.W_k = nn.Linear(dim, dim, bias=False)
        self.W_v = nn.Linear(dim, dim, bias=False)
        # 输出投影:让不同 head 的输出有互相组合的机会
        self.W_o = nn.Linear(dim, dim, bias=False)

        self.attn_drop = nn.Dropout(attn_dropout)
        self.resid_drop = nn.Dropout(resid_dropout)
        self.causal = causal

    def forward(self, x, return_attn=False):
        # x: (Batch, Time/序列长度, Channel/向量维度)
        B, T, C = x.shape
        H, D = self.n_head, self.head_dim

        # 1. 投影 + reshape + 转置:把 (B, T, C) 变成 (B, H, T, D)
        #    H 在第 1 维,后面矩阵乘会把 (B, H) 当批量
        q = self.W_q(x).view(B, T, H, D).transpose(1, 2)
        k = self.W_k(x).view(B, T, H, D).transpose(1, 2)
        v = self.W_v(x).view(B, T, H, D).transpose(1, 2)

        # 2. 算 attention scores:(B, H, T, D) @ (B, H, D, T) = (B, H, T, T)
        #    注意是 sqrt(D),不是 sqrt(C)!见上一节方差推导
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(D)

        # 3. 因果 mask:把未来位置的 score 设为 -inf
        #    softmax(-inf) = 0,未来位置就完全看不到
        if self.causal:
            mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            scores = scores.masked_fill(mask, float('-inf'))

        # 4. softmax + dropout + 加权求和
        attn = F.softmax(scores, dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v                                # (B, H, T, D)

        # 5. 把 head 维转回去,拼回 (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # 6. 输出投影 + dropout
        out = self.resid_drop(self.W_o(out))

        if return_attn:
            return out, attn   # attn 用于可视化
        return out

实现里有 5 个新手常踩的坑

  1. W_q 的输出维度是 dim 而不是 head_dim——256 实际上是"4 个 head 各 64 维的 Q"拼起来的,view 一下就切开了。这避免了写 4 个独立 Linear,参数总数也不变
  2. transpose(1, 2) 不可省——只有把 head 维提到 batch 维之后,矩阵乘才会按 head 独立算
  3. .contiguous() 在 transpose 之后必须加一次,否则 view 会报 view size is not compatible with input tensor's size and stride
  4. scale 用 sqrt(head_dim) 不是 sqrt(dim)——见上一节方差推导
  5. W_o 不能省——见后面"head 之间怎么通信"那一节

参数和 FLOP 怎么算

很多人第一次看到 MHA 会担心"开 8 个 head 是不是参数 8 倍"。其实不是,参数量跟 head 数完全无关。

dim = 512,无论 n_head 是 1、4、8 还是 16,参数都是:

组件shape参数量
W_q(dim, dim)dim² = 262144
W_k(dim, dim)dim² = 262144
W_v(dim, dim)dim² = 262144
W_o(dim, dim)dim² = 262144
总计4 × dim² ≈ 1M

多头只是把同一组参数"切分"成多个并行子空间使用——总维度没变,只是从 1 个 512 维空间变成 8 个 64 维空间。

FLOP(前面解释过:浮点运算次数,衡量计算量)也可以精确算。对一次 forward(batch=1,seq=T):

步骤FLOP
三个 Q/K/V 投影 Linear(dim, dim) 各做 T 次3 × T × dim²
Q @ K^T:把 (T, dim)(dim, T) 相乘T × T × dim = T² × dim
attn @ V:把 (T, T)(T, dim) 相乘T² × dim
输出投影 W_o(out)T × dim²
总计4 × T × dim² + 2 × T² × dim

观察两个结论:

  • T 短时(T < dim),FLOP 主导是 4 × T × dim²,跟 T 是线性关系
  • T 长时(T > dim),FLOP 主导是 2 × T² × dim,跟 T 是平方关系——这就是 attention 长上下文 O(T²) 复杂度的来源,也是 FlashAttention(项目 15)要解决的问题
  • n_head 不出现在公式里——MHA 跟单头总 FLOP 完全一致

每个 head 的 head_dim 变小是有代价的(256 → 32)。head_dim 太小时单 head 的表达力会下降。所以 head 数不是越多越好,要跟 dim 一起平衡。看几个真实模型的配置:

模型dimn_headhead_dim
GPT-2 small7681264
LLaMA 7B409632128
LLaMA 70B819264128
GPT-3 175B1228896128

head_dim = 64 或 128 是经验上的稳定值——再小(< 32)单 head 容量不够;再大(> 256)多头收益变薄。

不同 head 真的学到不同东西吗——画出来看

光说不够,把 attention 权重画出来。seabornheatmap 函数可以把一个二维矩阵渲染成颜色块图——颜色深浅对应数值大小:

import matplotlib.pyplot as plt
import seaborn as sns

# 假设已经训好一个有 MHA 的 toy LM
model.eval()
sample = "the quick brown fox jumps over the lazy dog"
ids = torch.tensor([[stoi[c] for c in sample]])   # stoi: 字符到 ID 的字典
with torch.no_grad():
    _, attn = model.mha(model.encode(ids), return_attn=True)
# attn shape: (1, n_head, T, T),每个 (T, T) 是一个 head 的注意力矩阵

n_head = attn.shape[1]
fig, axes = plt.subplots(1, n_head, figsize=(3*n_head, 3))
for h in range(n_head):
    sns.heatmap(
        attn[0, h].cpu().numpy(),
        ax=axes[h],
        cmap='Blues',            # 颜色映射:越蓝表示数值越大
        square=True,
        cbar=False,
        xticklabels=list(sample),
        yticklabels=list(sample),
    )
    axes[h].set_title(f"Head {h}")
plt.tight_layout()
plt.savefig("mha_heads.png", dpi=120)

跑出来你会看到几种典型模式(横轴是"被看的位置 j",纵轴是"看的位置 i"):

  • 对角线 head:颜色集中在 (i, i-1)(i, i-2) 附近,专门关注最近 1~2 个 token——做的是 n-gram 风格的局部依赖n-gram 指连续 n 个 token 的组合,比如 bigram 是相邻两个 token
  • bos 主导 headbos = beginning of sequence,序列开头那个特殊 token):所有位置都盯着开头那个 token——这是 attention sink 现象,大部分训练好的模型里都至少有 1~2 个 head 是这样。"sink" 字面是水槽,意思是"接收 attention 多余权重的吸收器"——具体原因学界还在研究,主流解释是模型需要一个"垃圾桶"放掉那些不知道放哪的注意力权重
  • 散点 head:每个位置都集中看几个特定远处 token,可能在学共指、句法、长程依赖
  • uniform head:权重接近均匀分布(每个位置都差不多)——这个 head 没在工作,是冗余的

最后一种现象不是个例。Michel 等人 2019 年的论文 Are Sixteen Heads Really Better than One? 系统验证过:训练好的 Transformer 里有 20%~40% 的 head 是冗余的——剪掉它们对下游性能影响 < 1%。后续 Voita 等人 2019 Analyzing Multi-Head Self-Attention 把 head 按功能分了 4 类(positional:关注固定位置;syntactic:关注语法相关的位置;rare-words:关注罕见词;其他:功能不明显),证明确实有一部分 head 可以被精确解释,另一部分主要是噪声。

head 之间怎么通信——拆 W_o 看清楚

多头的一个关键设计是 head 之间在 attention 内部完全独立——head 1 的输出不会进入 head 2 的计算。它们的唯一会面在最后的 W_o

可以用线性代数视角看 W_o 的作用。设拼接后的向量 concat_out 是把 h 个 (T, head_dim) 拼起来的 (T, dim)

W_o.shape = (dim, dim)
W_o @ concat_out 写成分块矩阵乘:
  = (W_o,1 | W_o,2 | ... | W_o,h) @ [out_1; out_2; ...; out_h]
  = Σ_j W_o,j @ out_j

其中 W_o,jW_o 切成 h 份后第 j 块,shape 是 (dim, head_dim)。也就是说,W_o 的每一行都是 h 个 head 输出向量的线性组合的系数。这正是不同 head 互相通信的唯一机制——如果没有 W_o,后续层只能看到 h 个互不相关的 head 输出拼起来的"长向量"。

实验 1:去掉 W_o

# forward 里把 self.W_o(out) 改成 out
return out

观察:验证集 loss 会显著上升(典型 +0.2~0.5)。多头退化成"h 个互不通气的小模型并联"。

实验 2:把所有 head 强制学一样

# 通过初始化和梯度 hook 让 W_q / W_k / W_v 的 h 个 head 切片永远相等
# 等价于"h 头但所有头算一样的东西"

观察:跟单头复制 h 份等价,loss 跟单头一致。证明多头的实际价值确实来自 head 之间的差异化

实验 3:n_head 扫描,dim 固定 256

n_headhead_dimval loss
12563.21
21283.05
4642.91
8322.94(接近最优)
16163.08(开始反弹)
3283.30(严重退化)

存在最优 n_head,超过之后 head_dim 太小不够用反而变差。这印证了前面 head_dim 经验值 64~128 的结论。

Self-attention vs cross-attention

MHA 不只用在 self-attention 里。先解释这两个词:

  • self-attention(自注意力):Q/K/V 都来自同一个输入序列。一个 token 在看同一句话里的其他 token
  • cross-attention(交叉注意力):Q 来自一个序列,K/V 来自另一个序列。一个序列的 token 在看另一个序列的 token

原始 Transformer 论文用在机器翻译上,decoder 里有两个 attention:

# decoder 内部
x = self_attention(x)         # Q/K/V 都来自 decoder 自己(看已生成的历史输出)
x = cross_attention(x, enc)   # Q 来自 decoder,K/V 来自 encoder(看待翻译的原文)

代码层面差异极小——只是 K/V 的来源换成另一个张量:

def cross_attention_forward(self, x, context):
    """x: 来自 decoder, context: 来自 encoder"""
    q = self.W_q(x)         # Q 用 decoder 状态
    k = self.W_k(context)   # K/V 用 encoder 状态
    v = self.W_v(context)
    # 后面的 reshape + attention 一模一样
    ...

重要事实:现代主流大模型(GPT 系列、LLaMA、Qwen、DeepSeek)都是 decoder-only 架构——没有 encoder,也就没有 cross-attention,只有 causal self-attention。encoder-decoder 架构(T5、BART、mT5、原始 Transformer 翻译用法)才有 cross-attention。这影响到你后续读论文时怎么理解模型结构。

GQA / MQA——MHA 在推理时的最大问题

现在主流大模型(LLaMA 3、Qwen 2、Mistral 等)几乎都不用纯 MHA 了,而是用 GQA(Grouped Query Attention):Query head 数多,K/V head 数少,多个 Q 共享一组 K/V。

为什么要做这个改动?要先理解 KV cache

KV cache——大模型生成文本是逐 token 进行的(先输出第 1 个 token,再输出第 2 个......)。生成第 N 个 token 时,attention 需要历史所有 token 的 K 和 V。如果每生成一个新 token 都重新算前面所有 K/V,O(T²) 的计算量会爆炸。所以工程上把已经算过的 K/V 缓存起来,每生成一个新 token 只算它自己的新 K/V,跟缓存拼一下用。这块缓存就是 KV cache。

KV cache 的显存占用怎么算?设:

L          模型层数
H_kv       每层的 K/V head 数
D          每个 head 的维度
T          上下文长度(已生成 + 待生成的 token 数)
batch      并发请求数(同时服务多少用户)
bytes      每个数值的字节数(fp16 = 2 字节)

总显存:

mem = 2 × L × H_kv × D × T × batch × bytes

(× 2 是因为 K 和 V 各存一份)

以 LLaMA 2 70B 为例(L=80, D=128),对比两种配置:

MHA 版本(H_kv = 64,全部 head 都各自的 K/V):
  2 × 80 × 64 × 128 × 4096 × 1 × 2 ≈ 10.7 GB    ← 单条请求 4K 上下文

GQA 版本(H_kv = 8,多个 Q 共享 K/V):
  2 × 80 × 8 × 128 × 4096 × 1 × 2 ≈ 1.3 GB

单条请求就省 8 倍显存。如果同时服务 32 个用户(batch = 32):

配置KV cache 总显存一张 H100(80 GB)能装下吗
MHA343 GB装不下,要 5 张卡才够
GQA43 GB单卡有余

这就是为什么所有大模型推理向 GQA 收敛。MQA(Multi-Query Attention)是更激进的版本——所有 Q head 共享 1 组 K/V,cache 最小但质量损失开始明显(一般掉 1~2% 任务准确率)。

GQA 的细节是项目 12 的主题(具体 H_q=32, H_kv=8 的实现 + 质量对比),这里只是让你知道:MHA 是教科书定义,但生产模型基本都换成了 GQA 的变体

一个常被忽略的细节:pre-LN vs post-LN

把 MHA 放进 Transformer block 时,LayerNorm 的放法有两种:

# Post-LN(原始 Transformer 2017 论文)
x = LayerNorm(x + MHA(x))
x = LayerNorm(x + FFN(x))

# Pre-LN(GPT-2 / LLaMA / Qwen 都用)
x = x + MHA(LayerNorm(x))
x = x + FFN(LayerNorm(x))

先解释这两种结构里的 "残差连接"(residual connection):x + f(x) 这种把输入和输出加起来的写法。它最早来自 ResNet(2015),让深层网络更易训——梯度可以从顶层"抄近道"直接流到底层。

Post-LN 和 Pre-LN 的区别是 LayerNorm 放在残差之后还是之前。差异看起来微小,但实测差距很大:

  • Post-LN 训练不稳定:早期 step 容易出现 NaN(无穷大或者无意义的数),需要 learning rate warmup 才能起步。warmup 指训练开始时先用很小的学习率,慢慢升到正常值,避免大梯度把模型一下子搞坏
  • Pre-LN 训练稳定:可以直接用大学习率,warmup 也可以省

为什么?Pre-LN 的残差是 x + 模块(LN(x)),残差路径上没有 LN,梯度从顶层直接流到底层不衰减。Post-LN 的残差被 LN "搅"了一次,深层时梯度逐层衰减。

所有现代 LLM 都用 Pre-LN。你写 MHA 时如果发现训练不稳定,先检查 LN 位置。

更进一步,LLaMA 把 LayerNorm 换成了 RMSNorm——x / RMS(x) * γ,相比 LayerNorm 去掉了减均值步骤,训练略快且效果几乎一致。这是项目 06(decoder block)的主题。

一个反直觉的事

写完 MHA 之后,可以反向问自己:为什么不用一个 256 维的大 head,而非要切成 4 个 64 维?

数学上单头的表达能力严格大于多头——多头是单头的一个约束子集("约束"指多头强制不同子空间不共享投影矩阵,而单头允许任意混合)。但实测多头一直赢,这有点反直觉。

主流解释有三种,每种都没有被严格证明:

  1. 优化更容易:多个独立子空间各自下降,比一个高维空间整体下降更容易找到好解。Adam 等优化器(优化器 指根据梯度更新参数的算法,Adam 是最常用的一种)在不同子空间能独立调节学习率
  2. 隐式正则化:强制 head 之间不通信,相当于约束模型不能依赖跨子空间的复杂耦合,泛化反而更好。泛化指模型在没见过的数据上的表现
  3. 硬件友好:多个小矩阵乘比一个大矩阵乘对 GPU 缓存和 tensor core(tensor core 是 NVIDIA GPU 里专门做矩阵乘的硬件单元)都更友好

工程上"多头比单头好"这件事经过了 8 年(2017 至今)所有模型反复印证,已经是事实,但底层原因仍是研究问题。

把这一节内容串起来

写到这里你应该能在脑子里回答这些问题:

  • 单头为什么不够:每对位置只能给一个标量权重,压不下多种语言关系
  • 为什么除 √d:让 softmax 输入的方差稳定在 1 量级,避免梯度死掉
  • 多头怎么并行:reshape + transpose 把 head 维提到 batch 之后用矩阵乘批量算
  • 参数和 FLOP 跟 head 数无关:head 切分只是分配方式,总维度没变
  • W_o 不能省:它是不同 head 唯一互相通信的地方
  • head_dim 经验值 64~128:再小单 head 容量不够,再大多头收益变薄
  • 生产模型用 GQA 而非 MHA:KV cache 显存差 8 倍以上,并发场景下决定能不能跑

下一篇项目 06 把 MHA 和 FFN 拼成完整的 Transformer decoder block,加上 residual、LayerNorm、合理的 dropout 位置,得到一个能堆深的最小单元。

延伸阅读

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:项目 05:从单头到多头注意力

本文链接:https://www.sshipanoo.com/blog/ai/llm-roadmap/lab-05-mha/

本文最后一次更新为 天前,文章中的某些内容可能已过时!