让每个字自己决定,该回头看前文的哪些部分

Bigram 的天花板:只有一个字的记忆

第三篇那个 Bigram 模型,笨在一个地方:预测下一个字时,它只看前一个字。"今天天气真"五个字,它只拿"真"去查表,前面四个字全扔了。

这就是它的天花板。语言里的信息是连贯的——要接好"今天天气真"后面那个字,得知道前面在聊天气;要写好一个长句的结尾,得记得开头说了什么。一个只有一个字记忆的模型,注定写不通顺。

所以接下来几篇的任务很明确:让模型看得见整段上下文。这一篇登场的注意力机制,就是解决这件事的核心零件,也是 GPT 里 T(Transformer)的灵魂。我们不直接甩出公式,而是从一个最笨的想法出发,一步步把它推出来。

先升级表示:每个字变成一个向量

动手之前,先做一个表示上的升级。

Bigram 里,一个字就是一个号码。但一个光秃秃的号码携带不了多少信息。从这一篇起,我们让每个字用一串数字来表示——一个向量。比如用 32 个数字代表一个字,这串数字就有空间编码"这个字偏正式还是口语、是名词还是动词"之类的特征。这串数字的长度,我们叫它 n_embd(embedding 维度)。

做法还是查表,只是表的形状变了:一张 vocab_size × n_embd 的表,每个字的号码对应其中一行,那一行 n_embd 个数就是这个字的向量。这张表本身也是旋钮,会在训练中被一起拧。

import torch
import torch.nn as nn

n_embd = 32
token_embedding_table = nn.Embedding(vocab_size, n_embd)

# 假设 idx 是一批文字的号码,形状 (B, T)
# 查表后,每个字从一个号码变成一个 n_embd 维向量
x = token_embedding_table(idx)      # 形状 (B, T, n_embd)

从现在起,模型内部流动的就是这种 (B, T, n_embd) 的数据:B 个样本,每个样本 T 个字,每个字是一个 n_embd 维向量。注意力机制要处理的,就是它。

第一个笨办法:把前文平均一下

现在让每个字"看见前文"。最笨、但方向正确的办法是:让每个位置的向量,等于它自己和它之前所有字向量的平均。

第 5 个字的新向量 = 第 1 到第 5 个字向量的平均。这样一来,第 5 个位置的向量里就"掺进"了前文的信息,不再只是它自己。

直接写循环求平均很慢。这里有个漂亮的技巧:平均可以用一次矩阵乘法完成。构造一个下三角矩阵,每一行做归一化,让它乘以字向量,得到的就是逐位置的累计平均。

import torch
from torch.nn import functional as F

T = 8
# 下三角全 1 矩阵:第 i 行只有前 i+1 个位置是 1
tril = torch.tril(torch.ones(T, T))
# 每行归一化,让一行的数加起来等于 1
weights = tril / tril.sum(dim=1, keepdim=True)
print(weights)

打印出来,weights 第 1 行是 [1, 0, 0, ...],第 2 行是 [0.5, 0.5, 0, ...],第 3 行是 [0.33, 0.33, 0.33, 0, ...]。拿它去乘字向量 x,第 i 个位置就自动得到了前 i 个字的平均。

# x 形状 (B, T, n_embd),weights 形状 (T, T)
x_averaged = weights @ x      # 每个位置变成"自己和前文的平均"

到这里,请记住这个加权求和的结构:每个位置的新向量,是把前文所有位置的向量,按一组权重加起来。这个结构,就是注意力的骨架。现在的权重是"平均"——大家一视同仁。注意力要做的,只是把这组权重换得聪明一点。

平均太粗暴:有的字更重要

为什么"平均"不够好?因为它把前文每个字看得一样重。

但真实情况不是这样。预测"今天天气真"后面那个字,"天气"这个词显然比"今"这个字更关键。一个好的机制,应该让模型自己判断:当前这个位置,前文里哪些字重要、哪些不重要,然后重的多看、轻的少看

也就是说,加权求和的那组权重,不该是固定的平均,而应该是模型根据内容动态算出来的。这正是注意力机制要做的事。

自注意力:查询、键、值

注意力怎么算这组权重?它给每个字都安排了三个角色,用三个向量表示,名字分别是查询(query)、键(key)、值(value)。

打个比方,想象一场配对。每个字一方面在"找":我这个位置,想从前文里找什么样的信息——这个需求写在它的查询向量里。另一方面每个字也在"亮牌":我这个字,能提供什么信息——这写在它的键向量里。

当某个位置要决定"前文那个字对我有多重要",它就拿自己的查询,去和那个字的键做匹配。匹配度高,说明那个字正好有我想要的东西,权重就大;匹配度低,权重就小。匹配度用两个向量的点积来算——点积大代表两个向量方向接近,也就是"需求和供给对上了"。

权重定了之后,真正被加权求和的,不是字向量本身,而是每个字的第三个向量——值。键负责"配对",值负责"配对成功后实际交出去的内容",两者分开。

把这套用代码写出来。三个角色各用一个线性层从字向量变换得到:

import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, n_embd = 4, 8, 32
head_size = 16          # 查询/键/值向量的长度

x = torch.randn(B, T, n_embd)     # 假装是一批字向量

# 三个线性层,分别产出 查询、键、值
query = nn.Linear(n_embd, head_size, bias=False)
key = nn.Linear(n_embd, head_size, bias=False)
value = nn.Linear(n_embd, head_size, bias=False)

q = query(x)        # (B, T, head_size),每个字"想找什么"
k = key(x)          # (B, T, head_size),每个字"能提供什么"
v = value(x)        # (B, T, head_size),每个字"实际交出的内容"

# 每个字的查询,和所有字的键做点积,得到一张"重要性"表
weights = q @ k.transpose(-2, -1)        # (B, T, T)
weights = weights * head_size ** -0.5    # 缩放,原因下面讲

weights 的形状是 (B, T, T):对每个样本,它是一张 T 行 T 列的表,第 i 行第 j 列代表"第 i 个字对第 j 个字的关注度"。这张表就是我们要的那组动态权重——由内容算出来,不再是死板的平均。

代码里那行乘以 head_size ** -0.5 是缩放。点积的数值会随向量长度变大,数值太大的话,下一步的 softmax 会变得极端——几乎把全部权重压给一个字、其它字归零。除以 head_size 的平方根,把数值拉回温和的范围,让权重分布柔和一些。

不能偷看未来:因果掩码

还差一步,而且这一步至关重要。

我们的模型是要预测下一个字。训练时,第 3 个位置的任务是"根据前 3 个字,预测第 4 个字"。那它在算注意力时,绝对不能看到第 4、第 5 个字——那是答案,看了就等于作弊,训出来的模型一到真实生成(没有答案)就废。

所以要给注意力加一道限制:每个位置只许看自己和它之前的字,不许看后面的。这道限制叫因果掩码(causal mask)。

实现办法又用到了下三角矩阵:把权重表里"属于未来"的那些格子(上三角部分),强行设成负无穷。负无穷经过 softmax 之后会变成 0,等于这些未来位置的权重被彻底清零。

tril = torch.tril(torch.ones(T, T))
# 把上三角(未来位置)填成负无穷
weights = weights.masked_fill(tril == 0, float("-inf"))
# softmax:把每一行的分数换算成加起来为 1 的权重
weights = F.softmax(weights, dim=-1)

经过 masked_fillsoftmaxweights 每一行只有"当前位置及之前"非零,且一行加起来等于 1。这正是一组合法的、只看得见过去的加权系数。

一个完整的注意力头

把上面所有步骤接起来——查询键值、点积、缩放、因果掩码、softmax,最后用权重对值做加权求和——就是一个完整的注意力头(attention head):

# 续用上面的 q, k, v, B, T, head_size
weights = q @ k.transpose(-2, -1) * head_size ** -0.5   # 算重要性
tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril == 0, float("-inf")) # 屏蔽未来
weights = F.softmax(weights, dim=-1)                    # 换算成权重

out = weights @ v       # 用权重对"值"加权求和

print(f"输出形状:{out.shape}")     # (B, T, head_size)

最后这行 weights @ v,结构和文章开头那个"平均"一模一样——还是加权求和。唯一的、也是全部的进步在于:那组权重不再是僵硬的平均,而是模型根据每个字的查询和键、自己算出来的。重要的字权重大,不相关的字权重小。

这就是注意力机制。每个字不再只盯着前一个字,而是回头扫视整段前文,自己决定该重点看哪几个。Bigram 那个"只有一个字记忆"的天花板,被它彻底捅破了。

本篇要点

  • Bigram 只看前一个字,天花板很低;模型需要看见整段上下文。
  • 从这一篇起,每个字用一个 n_embd 维向量表示,比单个号码能携带更多信息。
  • 让每个位置"看见前文",本质是对前文向量做加权求和;最笨的权重是平均。
  • 注意力把平均换成动态权重:每个字有查询、键、值三个向量,查询和键点积算出关注度。
  • 因果掩码用下三角把"未来位置"的权重清零,保证模型预测时不偷看答案。
  • 一个注意力头 = 查询键值、点积、缩放、掩码、softmax、对值加权求和。

下一篇

一个注意力头能让模型看见上下文,但还不够强。下一篇把多个注意力头、前馈网络、残差连接、层归一化、位置编码这些零件,拼装成一个真正完整的 GPT 结构。

参考资料

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:注意力机制:让模型看见上下文

本文链接:https://www.sshipanoo.com/blog/ai/mini-gpt/05-注意力机制让模型看见上下文/

本文最后一次更新为 天前,文章中的某些内容可能已过时!