attention 是集合操作,位置必须显式注入
{{ intro_card( points=[ "实测 attention 没有位置编码时的"我打你 = 你打我"问题", "四种主流位置编码完整实现:learned / sinusoidal / RoPE / ALiBi", "RoPE 的旋转矩阵原理 + 为什么它能优雅支持长上下文外推", "破坏实验:换 / 拿掉位置编码、超出训练长度时分别会怎样", ], audience="想理解为什么 RoPE 在现代 LLM 普及、需要做长上下文外推的工程师", prerequisites="读完项目 04(attention)会更好理解,但本篇会先讲清动机", time="40 分钟", ) }}
项目目标
attention 在数学上是一个集合操作——给一组 query / key / value,输出是它们的加权聚合,结果跟输入顺序无关。但语言显然有顺序:我打你 和 你打我 的 token 集合一样,含义完全相反。
要让 Transformer 知道顺序,必须显式注入位置信息。这个项目实现并对比四种主流方案:
| 方案 | 出处 | 加在哪 | 外推性 | 现状 |
|---|---|---|---|---|
| Learned | BERT、GPT-2 | 加到 input embedding | 不能超过训练长度 | 早期主流 |
| Sinusoidal | 原始 Transformer | 加到 input embedding | 理论可外推但实测一般 | 教学经典 |
| RoPE | RoFormer、Llama、Qwen | 作用于 Q/K(不加到 embedding) | 较好,可配合 scaling | 现代 LLM 主流 |
| ALiBi | BLOOM、MPT | 加到 attention score 作为偏置 | 强,天然支持外推 | 长上下文场景常见 |
背景与原理
不加位置时 attention 是什么样的
scaled dot-product attention:
Attention(Q, K, V) = softmax(Q @ K.T / √d) @ V
如果输入序列 [x1, x2, x3] 换成 [x3, x1, x2],Q / K / V 也跟着换顺序,输出的每一行同样会按相应顺序换——但每个位置的输出只依赖于该位置的 Q 和所有 K/V 的内容,不依赖于"我是第几个"。
所以两个 token 序列只要集合一样,attention 输出就只是行的重排,不是不同的语义。这就是为什么必须加位置。
Learned 位置编码
最直接的方法:再开一个 embedding table,每个位置 ID(0, 1, 2, ...)对应一个可学习向量,加到 token embedding 上:
pos_emb = nn.Embedding(max_seq_len, dim)
x = token_emb(ids) + pos_emb(torch.arange(seq_len))
优点:简单、和 token embedding 同质。 缺点:不能超出训练时见过的最大位置——位置 1024 的向量是训练出来的,问位置 2000 等于查到了未训练的随机行。
Sinusoidal 位置编码
原始 Transformer 论文用的方法:
PE(pos, 2i) = sin(pos / 10000^(2i/dim))
PE(pos, 2i+1) = cos(pos / 10000^(2i/dim))
每个维度对应一个不同频率的正余弦波,位置 pos 在这些波上的取值组成它的位置向量。
优点:理论上能外推到任意位置(公式是封闭的)。 缺点:实测外推效果一般,且现代 LLM 普遍认为它对长程依赖学得不太好。
RoPE(Rotary Position Embedding)
RoPE 是当前主流(Llama 1/2/3、Qwen、Mistral、Deepseek 全部用)。核心思想:不要把位置加到 input embedding 上,而是直接旋转 Q 和 K。
设位置 m 的旋转角度是 θ_m。对每对维度 (2i, 2i+1),应用一个 2D 旋转矩阵:
[q'_2i ] [cos(mθ) -sin(mθ)] [q_2i ]
[q'_2i+1] = [sin(mθ) cos(mθ)] [q_2i+1]
旋转后做 Q @ K.T 时,每一项乘积里位置以差值 (m - n) 的形式出现——attention score 自然依赖相对位置,而不是绝对位置。
为什么这样设计好:
- 不占用 input embedding 容量——位置信息和语义信息解耦
- 天然是相对位置——比绝对位置在文本场景更符合语言直觉(语义依赖距离,不依赖绝对偏移)
- 可以外推——通过
RoPE scaling(位置插值 / NTK / YaRN)可以把训练长度 4K 的模型扩到 32K、100K
ALiBi(Attention with Linear Biases)
更激进:连位置编码都不要,直接在 attention score 上加一个跟距离成正比的负偏置:
score(i, j) = q_i @ k_j - m * |i - j|
m 是每个 head 不同的固定斜率。距离越远,score 越被压低;近的位置占优。
优点:实现极简、外推性极强(训练 1K 推理 16K 都没大问题)。 缺点:表达力受限——不能像 RoPE 那样捕捉复杂的位置关系。
动手实现
先准备一个共享的 attention 计算函数,方便后面四种位置方案对比:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def attn(q, k, v, mask=None, score_bias=None):
"""q/k/v: (B, H, T, D)"""
d = q.shape[-1]
scores = (q @ k.transpose(-2, -1)) / math.sqrt(d)
if score_bias is not None:
scores = scores + score_bias # ALiBi 用
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
return F.softmax(scores, dim=-1) @ v
方案 A:Learned
class LearnedPosModel(nn.Module):
def __init__(self, vocab, dim, max_len=1024):
super().__init__()
self.tok_emb = nn.Embedding(vocab, dim)
self.pos_emb = nn.Embedding(max_len, dim)
def encode(self, ids):
T = ids.size(1)
pos = torch.arange(T, device=ids.device)
return self.tok_emb(ids) + self.pos_emb(pos)
方案 B:Sinusoidal
def sinusoidal_pe(max_len, dim):
pe = torch.zeros(max_len, dim)
pos = torch.arange(0, max_len).unsqueeze(1)
div = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
return pe # shape (max_len, dim)
class SinPosModel(nn.Module):
def __init__(self, vocab, dim, max_len=4096):
super().__init__()
self.tok_emb = nn.Embedding(vocab, dim)
self.register_buffer('pe', sinusoidal_pe(max_len, dim))
def encode(self, ids):
T = ids.size(1)
return self.tok_emb(ids) + self.pe[:T]
方案 C:RoPE
RoPE 的核心是一个旋转函数,作用在 Q 和 K 上(不动 V):
def precompute_freqs(dim, max_len, theta=10000.0):
"""预计算每个 (位置, 维度对) 的旋转角度的 cos/sin"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim//2].float() / dim))
t = torch.arange(max_len)
angles = torch.outer(t, freqs) # (max_len, dim/2)
return torch.cos(angles), torch.sin(angles)
def apply_rope(x, cos, sin):
"""对最后一维做 2D 旋转。x: (..., T, D)"""
x1, x2 = x[..., ::2], x[..., 1::2] # 切成奇偶维
cos = cos[: x.shape[-2]].unsqueeze(0) # 广播到 batch / head
sin = sin[: x.shape[-2]].unsqueeze(0)
rotated = torch.empty_like(x)
rotated[..., ::2] = x1 * cos - x2 * sin
rotated[..., 1::2] = x1 * sin + x2 * cos
return rotated
class RoPEAttention(nn.Module):
def __init__(self, dim, n_head, max_len=4096):
super().__init__()
self.n_head = n_head
self.head_dim = dim // n_head
self.qkv = nn.Linear(dim, 3 * dim, bias=False)
self.out = nn.Linear(dim, dim, bias=False)
cos, sin = precompute_freqs(self.head_dim, max_len)
self.register_buffer('cos', cos)
self.register_buffer('sin', sin)
def forward(self, x):
B, T, C = x.shape
q, k, v = self.qkv(x).chunk(3, dim=-1)
# 切 head: (B, T, H, D) -> (B, H, T, D)
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
# 关键:只旋转 Q 和 K,V 不动
q = apply_rope(q, self.cos, self.sin)
k = apply_rope(k, self.cos, self.sin)
y = attn(q, k, v)
return self.out(y.transpose(1, 2).contiguous().view(B, T, C))
方案 D:ALiBi
def alibi_bias(n_head, max_len):
"""每个 head 一个固定斜率 m,bias[i,j] = -m * |i-j|"""
# 标准做法:m 取 2^(-8/n_head * (h+1))
slopes = torch.tensor([2 ** (-8 * (h+1) / n_head) for h in range(n_head)])
pos = torch.arange(max_len)
dist = (pos[None, :] - pos[:, None]).abs().float() # (T, T)
return -slopes[:, None, None] * dist[None] # (H, T, T)
class ALiBiAttention(nn.Module):
def __init__(self, dim, n_head, max_len=4096):
super().__init__()
self.n_head = n_head
self.head_dim = dim // n_head
self.qkv = nn.Linear(dim, 3 * dim, bias=False)
self.out = nn.Linear(dim, dim, bias=False)
self.register_buffer('alibi', alibi_bias(n_head, max_len))
def forward(self, x):
B, T, C = x.shape
q, k, v = self.qkv(x).chunk(3, dim=-1)
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
bias = self.alibi[:, :T, :T].unsqueeze(0) # (1, H, T, T)
y = attn(q, k, v, score_bias=bias)
return self.out(y.transpose(1, 2).contiguous().view(B, T, C))
观察指标
在同一个 tiny LM(同样数据、同样网络结构,仅换位置方案)上记录:
| 指标 | 怎么测 |
|---|---|
| 训练 loss / 验证 loss | 同样 step 数下的曲线 |
| 训练长度内的 attention pattern | heatmap,观察是否形成对角线(近邻依赖)+ 散点(长程) |
| 外推:训练长度 256,测 512/1024 时 perplexity | 真正的位置方案差距体现 |
| Needle-in-a-haystack:在长上下文中藏一个事实,问能否找到 | 检验长程依赖能力 |
典型实验结果(256 训练,外推测试):
| 方案 | 训练长度 ppl | 512 ppl | 1024 ppl | 2048 ppl |
|---|---|---|---|---|
| Learned | 3.2 | 崩了(高于 50) | - | - |
| Sinusoidal | 3.3 | 4.1 | 6.8 | 15.0 |
| RoPE | 3.2 | 3.5 | 4.2 | 6.5 |
| RoPE + NTK scaling | 3.2 | 3.4 | 3.7 | 4.5 |
| ALiBi | 3.4 | 3.6 | 3.9 | 4.6 |
可以清楚看到:learned 不能外推,sinusoidal 勉强能但衰减很快,RoPE 加上 scaling、ALiBi 是当前长上下文的两大主流方案。
破坏实验
实验 1:完全去掉位置编码
# 直接用 token embedding 不加位置
x = self.tok_emb(ids)
# 训练 loss 会卡在一个比较高的水平,因为模型分不清"我打你"和"你打我"
可以构造一个 unscramble 任务专门测:给模型一个被打乱的句子,让它输出原句。无位置编码的模型在这个任务上完全学不会。
实验 2:关掉 causal mask,再去位置编码
# 同时移除 mask 和 position
# 模型变成"看一袋词预测下一个",相当于词袋语言模型,性能远差
实验 3:RoPE 中 theta 改成极端值
# 标准 theta=10000,改成 theta=10:
# 高频率震荡导致远距离 token 的旋转角度差异巨大,注意力变成"只看近邻"
# 改成 theta=1000000:旋转太慢,远近难以区分
实验 4:训练长度 256,直接推 4096,不做 scaling
# RoPE 模型在 1024 之内还行,超出后 ppl 飙升
# 这就是为什么 Llama 系列长上下文版本(Llama-2-32k 等)都要做 scaled rope / NTK / YaRN
交付物
pos_learned.py/pos_sin.py/pos_rope.py/pos_alibi.py:四种实现train_compare.py:同一份小语料 + 同一份网络结构 + 四种位置方案各训一次extrapolation_test.py:训练长度 256,测 512/1024/2048 的 perplexity- 一张表:四种方案的训练 ppl / 外推 ppl / 实现复杂度对比
- 一张 RoPE 旋转矩阵的可视化(不同位置在 2D 平面上的指向)
- 200 字短复盘:为什么 RoPE 和 ALiBi 取代了 learned / sinusoidal
与本站其他内容连接
- 项目 02:embedding 与语义几何——位置编码加到 / 旋转的对象就是它产出的向量
- 项目 04:手写 attention——RoPE / ALiBi 都要嵌入 attention 内部
- mini-gpt 06:拼出一个真正的 GPT 结构——位置编码在实际 GPT 里的位置
延伸阅读
- RoFormer 原始论文(RoPE)
- ALiBi 原始论文
- Su Jianlin: 让研究人员爱不释手的 Transformer 位置编码
- YaRN: Efficient Context Window Extension(RoPE scaling)
- The Annotated Transformer(含 sinusoidal 推导)
版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。
(采用 CC BY-NC-SA 4.0 许可协议进行授权)
本文标题:项目 03:位置编码(learned / sinusoidal / RoPE / ALiBi)
本文链接:https://www.sshipanoo.com/blog/ai/llm-roadmap/项目03-位置编码/
本文最后一次更新为 天前,文章中的某些内容可能已过时!