自注意力、多头注意力与位置编码
前言
注意力机制允许模型动态地关注输入的不同部分,是Transformer和现代大语言模型的核心组件。
注意力的直觉
人类注意力
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(42)
# Simulate the attention distribution of a human reader over a sentence
def visualize_attention_intuition():
    """Bar-chart intuition: attention weights over the sentence while focusing on "cat"."""
    sentence = "The cat sat on the mat"
    words = sentence.split()
    # Hand-picked attention paid to each word while focusing on "cat"
    attention_weights = np.array([0.1, 0.5, 0.15, 0.05, 0.05, 0.15])

    fig, ax = plt.subplots(figsize=(12, 4))
    bars = ax.bar(range(len(words)), attention_weights, color='steelblue')
    bars[1].set_color('coral')  # highlight the attended word ("cat")
    ax.set_xticks(range(len(words)))
    ax.set_xticklabels(words, fontsize=12)
    ax.set_ylabel('注意力权重')
    ax.set_title('阅读"cat"时对其他词的注意力分布')
    # Label each bar with its numeric weight
    for position, weight in enumerate(attention_weights):
        ax.annotate(f'{weight:.2f}', xy=(position, weight), ha='center', va='bottom')
    plt.tight_layout()
    plt.show()

visualize_attention_intuition()
基础注意力机制
Query-Key-Value框架
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]def scaled_dot_product_attention(Q, K, V, mask=None):
"""
缩放点积注意力
Q, K, V: (batch_size, seq_len, d_k)
"""
d_k = Q.shape[-1]
# 计算注意力分数
scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
# 应用mask(可选)
if mask is not None:
scores = scores + (mask * -1e9)
# Softmax
attention_weights = softmax(scores, axis=-1)
# 加权求和
output = np.matmul(attention_weights, V)
return output, attention_weights
def softmax(x, axis=-1):
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
# Smoke test: run random Q/K/V through scaled dot-product attention
batch_size, seq_len, d_k = 2, 5, 8
Q = np.random.randn(batch_size, seq_len, d_k)
K = np.random.randn(batch_size, seq_len, d_k)
V = np.random.randn(batch_size, seq_len, d_k)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Q, K, V 形状: {Q.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"\n注意力权重示例 (batch=0):")
print(weights[0])
可视化注意力权重
def visualize_attention_weights(weights, tokens_q=None, tokens_k=None):
    """Render a 2-D attention-weight matrix as an annotated heatmap."""
    n_q, n_k = weights.shape[0], weights.shape[1]
    if tokens_q is None:
        tokens_q = [f'Q{i}' for i in range(n_q)]
    if tokens_k is None:
        tokens_k = [f'K{i}' for i in range(n_k)]

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(weights, cmap='Blues')
    ax.set_xticks(np.arange(len(tokens_k)))
    ax.set_yticks(np.arange(len(tokens_q)))
    ax.set_xticklabels(tokens_k)
    ax.set_yticklabels(tokens_q)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

    # Overlay each cell's numeric value; switch to white text on dark cells
    for row in range(len(tokens_q)):
        for col in range(len(tokens_k)):
            cell_color = "black" if weights[row, col] < 0.5 else "white"
            ax.text(col, row, f'{weights[row, col]:.2f}',
                    ha="center", va="center", color=cell_color)

    ax.set_title("注意力权重")
    ax.set_xlabel("Key")
    ax.set_ylabel("Query")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()

# Visualize the first batch item from the smoke test above
tokens = ['I', 'love', 'machine', 'learning', '.']
visualize_attention_weights(weights[0], tokens, tokens)
自注意力(Self-Attention)
概念
Q、K、V来自同一个输入序列。
class SelfAttention:
    """Self-attention layer: Q, K and V are all projections of the same input."""

    def __init__(self, d_model, d_k):
        self.d_k = d_k
        # Xavier/Glorot-style scale shared by the three projection matrices
        scale = np.sqrt(2.0 / (d_model + d_k))
        self.W_q = np.random.randn(d_model, d_k) * scale
        self.W_k = np.random.randn(d_model, d_k) * scale
        self.W_v = np.random.randn(d_model, d_k) * scale

    def forward(self, X, mask=None):
        """
        X: (batch_size, seq_len, d_model)
        Returns (output, attention_weights) from scaled dot-product attention.
        """
        # Project the shared input into query, key and value spaces
        projections = [X @ W for W in (self.W_q, self.W_k, self.W_v)]
        return scaled_dot_product_attention(*projections, mask)
# Quick shape check for the self-attention layer
d_model, d_k = 64, 32
self_attn = SelfAttention(d_model, d_k)
X = np.random.randn(2, 10, d_model)  # 2 samples, 10 tokens, 64-dim embeddings
output, weights = self_attn.forward(X)
print(f"输入形状: {X.shape}")
print(f"输出形状: {output.shape}")
自注意力 vs 全连接
| 特性 | 全连接 | 自注意力 |
|---|---|---|
| 参数量 | $O(n^2 d)$ | $O(d^2)$ |
| 计算复杂度 | $O(n d^2)$ | $O(n^2 d)$ |
| 长程依赖 | 间接 | 直接 |
多头注意力
原理
并行运行多个注意力头,捕获不同的关系模式:
\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O\]class MultiHeadAttention:
"""多头注意力"""
def __init__(self, d_model, n_heads):
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 每个头的权重
self.W_q = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
self.W_k = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
self.W_v = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
self.W_o = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
def split_heads(self, x):
"""
x: (batch, seq_len, d_model)
返回: (batch, n_heads, seq_len, d_k)
"""
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, self.n_heads, self.d_k)
return x.transpose(0, 2, 1, 3)
def concat_heads(self, x):
"""
x: (batch, n_heads, seq_len, d_k)
返回: (batch, seq_len, d_model)
"""
batch_size, _, seq_len, _ = x.shape
x = x.transpose(0, 2, 1, 3)
return x.reshape(batch_size, seq_len, -1)
def forward(self, X, mask=None):
# 线性变换
Q = X @ self.W_q
K = X @ self.W_k
V = X @ self.W_v
# 分割成多头
Q = self.split_heads(Q)
K = self.split_heads(K)
V = self.split_heads(V)
# 计算每个头的注意力
d_k = Q.shape[-1]
scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(d_k)
if mask is not None:
scores = scores + (mask * -1e9)
weights = softmax(scores, axis=-1)
context = np.matmul(weights, V)
# 合并多头
context = self.concat_heads(context)
# 输出投影
output = context @ self.W_o
return output, weights
# Smoke test: 8 heads over a 64-dim model
d_model, n_heads = 64, 8
mha = MultiHeadAttention(d_model, n_heads)
X = np.random.randn(2, 10, d_model)
output, weights = mha.forward(X)
print(f"输入形状: {X.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")  # (batch, n_heads, seq, seq)
可视化多头注意力
def visualize_multihead_attention(weights, head_idx=None):
    """Show heatmaps for all heads (head_idx=None) or a single selected head."""
    if head_idx is not None:
        # Single head: reuse the generic heatmap helper on batch item 0
        visualize_attention_weights(weights[0, head_idx])
        return

    n_heads = weights.shape[1]
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    # Plot at most the first 8 heads into the 2x4 grid
    for i, ax in enumerate(axes.flatten()[:min(n_heads, 8)]):
        ax.imshow(weights[0, i], cmap='Blues')
        ax.set_title(f'Head {i+1}')
        ax.set_xlabel('Key')
        ax.set_ylabel('Query')
    plt.suptitle('多头注意力权重', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_multihead_attention(weights)
位置编码
为什么需要位置编码
自注意力是置换不变的,需要额外的位置信息。
正弦位置编码
\(PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})\) \(PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})\)
def positional_encoding(max_len, d_model):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Returns a (max_len, d_model) array where even columns hold
    sin(pos / 10000^(2i/d_model)) and odd columns the matching cosine.

    Handles both even and odd d_model: the original assignment broke for odd
    widths because the cosine slice has one fewer column than div_term.
    """
    pe = np.zeros((max_len, d_model))
    position = np.arange(max_len)[:, np.newaxis]
    # Geometric frequency progression, computed in log space for stability
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    # For odd d_model there is one fewer cosine column than sine column,
    # so truncate div_term to the number of odd-indexed columns.
    pe[:, 1::2] = np.cos(position * div_term[: (d_model // 2)])
    return pe
# Generate the sinusoidal positional encoding and visualize it
max_len, d_model = 100, 64
pe = positional_encoding(max_len, d_model)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: heatmap of the first 50 positions x 32 dimensions
ax = axes[0]
im = ax.imshow(pe[:50, :32], cmap='RdBu', aspect='auto')
ax.set_xlabel('维度')
ax.set_ylabel('位置')
ax.set_title('位置编码热力图')
plt.colorbar(im, ax=ax)

# Right panel: waveforms of a few selected dimensions
ax = axes[1]
for dim in [0, 4, 8, 16]:
    ax.plot(pe[:50, dim], label=f'dim={dim}')
ax.set_xlabel('位置')
ax.set_ylabel('编码值')
ax.set_title('不同维度的位置编码')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
print(f"位置编码形状: {pe.shape}")
可学习位置编码
class LearnedPositionalEncoding:
    """Learned positional embeddings: a trainable (max_len, d_model) table."""

    def __init__(self, max_len, d_model):
        # Small random init, analogous to typical embedding-table initialization
        self.pe = 0.02 * np.random.randn(max_len, d_model)

    def forward(self, seq_len):
        """Return the first seq_len position vectors, shape (seq_len, d_model)."""
        return self.pe[:seq_len]
# Usage example: slice out the encodings for a 10-token sequence
learned_pe = LearnedPositionalEncoding(max_len=512, d_model=64)
pe_slice = learned_pe.forward(seq_len=10)
print(f"可学习位置编码形状: {pe_slice.shape}")
Mask机制
Padding Mask
def create_padding_mask(seq, pad_token=0):
    """Build a padding mask: 1.0 where seq equals pad_token, 0.0 elsewhere.

    Output shape is (batch, 1, 1, seq_len) so it broadcasts against
    (batch, n_heads, seq_len, seq_len) attention scores.
    """
    flags = np.where(seq == pad_token, 1.0, 0.0)
    # Insert singleton head and query axes for broadcasting
    return flags[:, np.newaxis, np.newaxis, :]
# Example: batch of two sequences, padded with 0
seq = np.array([[1, 2, 3, 0, 0],
[1, 2, 3, 4, 0]])  # 0 is the padding token
padding_mask = create_padding_mask(seq)
print("Padding mask:")
print(padding_mask.squeeze())
Causal Mask(因果掩码)
def create_causal_mask(seq_len):
    """Causal (look-ahead) mask for autoregressive generation.

    Returns a (seq_len, seq_len) float matrix with 1.0 strictly above the
    diagonal (future positions, to be masked) and 0.0 on/below it.
    """
    lower_triangle = np.tri(seq_len)  # 1s on and below the diagonal
    return 1.0 - lower_triangle
# Example: print a 5x5 causal mask
causal_mask = create_causal_mask(5)
print("Causal mask:")
print(causal_mask)
# Visualization: plotting 1 - mask flips the convention to "1 = visible"
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax = axes[0]
ax.imshow(1 - create_causal_mask(8), cmap='Greens')
ax.set_title('Causal Mask (绿色=可见)')
ax.set_xlabel('Key位置')
ax.set_ylabel('Query位置')
ax = axes[1]
# Uses `seq` from the padding-mask example above
ax.imshow(1 - create_padding_mask(seq)[0, 0], cmap='Greens')
ax.set_title('Padding Mask (绿色=可见)')
ax.set_xlabel('Key位置')
ax.set_ylabel('Query位置')
plt.tight_layout()
plt.show()
PyTorch实现
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiHeadAttentionPyTorch(nn.Module):
        """PyTorch multi-head attention mirroring the NumPy implementation above."""

        def __init__(self, d_model, n_heads, dropout=0.1):
            super().__init__()
            self.d_model = d_model
            self.n_heads = n_heads
            self.d_k = d_model // n_heads
            self.W_q = nn.Linear(d_model, d_model)
            self.W_k = nn.Linear(d_model, d_model)
            self.W_v = nn.Linear(d_model, d_model)
            self.W_o = nn.Linear(d_model, d_model)
            self.dropout = nn.Dropout(dropout)

        def forward(self, Q, K, V, mask=None):
            bsz = Q.size(0)

            def project(layer, x):
                # (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
                return layer(x).view(bsz, -1, self.n_heads, self.d_k).transpose(1, 2)

            q = project(self.W_q, Q)
            k = project(self.W_k, K)
            v = project(self.W_v, V)

            # Scaled dot-product attention per head
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
            if mask is not None:
                # NOTE(review): here mask == 0 marks blocked positions — the
                # opposite convention of the additive NumPy masks above;
                # confirm before sharing masks between the two implementations.
                scores = scores.masked_fill(mask == 0, -1e9)
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            context = torch.matmul(attn_weights, v)

            # Merge heads back to (batch, seq, d_model) and project
            context = context.transpose(1, 2).contiguous().view(bsz, -1, self.d_model)
            return self.W_o(context), attn_weights

    # Smoke test of the custom module
    mha = MultiHeadAttentionPyTorch(d_model=512, n_heads=8)
    x = torch.randn(2, 10, 512)
    output, weights = mha(x, x, x)
    print("PyTorch多头注意力:")
    print(f" 输入: {x.shape}")
    print(f" 输出: {output.shape}")
    print(f" 注意力权重: {weights.shape}")

    # Compare against torch's built-in MultiheadAttention
    mha_builtin = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
    output_builtin, weights_builtin = mha_builtin(x, x, x)
    print(f"\n内置MultiheadAttention输出: {output_builtin.shape}")
except ImportError:
    print("PyTorch未安装")
常见问题
Q1: 为什么要除以$\sqrt{d_k}$?
防止点积过大导致softmax梯度过小。
Q2: 多头注意力有什么好处?
- 捕获不同类型的关系
- 增加表达能力
- 类似CNN的多通道
Q3: 位置编码为什么用正弦函数?
- 可以表示相对位置
- 可以外推到更长序列
- 不需要学习
Q4: 自注意力的计算复杂度?
$O(n^2 d)$,对于长序列是瓶颈,有多种优化方法(Sparse、Linear Attention等)。
总结
| 概念 | 描述 |
|---|---|
| 注意力 | 动态加权聚合信息 |
| 自注意力 | Q、K、V来自同一输入 |
| 多头注意力 | 并行多个注意力头 |
| 位置编码 | 注入序列位置信息 |
参考资料
- Vaswani, A. et al. (2017). “Attention Is All You Need”
- The Illustrated Transformer (Jay Alammar)
- The Annotated Transformer (Harvard NLP)
- CS224N: Attention Mechanisms
版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。
(采用 CC BY-NC-SA 4.0 许可协议进行授权)
本文标题:《 机器学习基础系列——注意力机制 》
本文链接:http://localhost:3015/ai/%E6%B3%A8%E6%84%8F%E5%8A%9B%E6%9C%BA%E5%88%B6.html
本文最后一次更新时间距今可能已久,文章中的某些内容可能已过时!