自注意力、多头注意力与位置编码

前言

注意力机制允许模型动态地关注输入的不同部分,是Transformer和现代大语言模型的核心组件。


注意力的直觉

人类注意力

import numpy as np
import matplotlib.pyplot as plt

# Fix the RNG seed so the random tensors generated below are reproducible.
np.random.seed(42)

# Simulate how a reader distributes attention over a sentence.
def visualize_attention_intuition():
    """Bar chart of a toy attention distribution while reading the word "cat"."""
    words = "The cat sat on the mat".split()

    # Hand-picked attention weights when the focus word is "cat".
    attention_weights = np.array([0.1, 0.5, 0.15, 0.05, 0.05, 0.15])

    fig, ax = plt.subplots(figsize=(12, 4))

    bars = ax.bar(range(len(words)), attention_weights, color='steelblue')
    bars[1].set_color('coral')  # the word being attended to

    ax.set_xticks(range(len(words)))
    ax.set_xticklabels(words, fontsize=12)
    ax.set_ylabel('注意力权重')
    ax.set_title('阅读"cat"时对其他词的注意力分布')

    # Label each bar with its numeric weight.
    for idx, weight in enumerate(attention_weights):
        ax.annotate(f'{weight:.2f}', xy=(idx, weight), ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

visualize_attention_intuition()

基础注意力机制

Query-Key-Value框架

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Scaled dot-product attention.

    Args:
        Q, K, V: arrays of shape (batch_size, seq_len, d_k).
        mask: optional array broadcastable to the score tensor; entries
            equal to 1 mark positions to be suppressed.

    Returns:
        (output, attention_weights): output has the shape of V, and
        attention_weights is (batch_size, seq_len, seq_len).
    """
    dim = Q.shape[-1]

    # Query/key similarity, scaled by sqrt(d_k) to keep the logits in a
    # range where softmax still has useful gradients.
    scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(dim)

    # Masked positions get a huge negative logit -> ~0 weight after softmax.
    if mask is not None:
        scores = scores + (mask * -1e9)

    # Normalise over the key axis, then take the weighted sum of values.
    attention_weights = softmax(scores, axis=-1)
    return np.matmul(attention_weights, V), attention_weights


def softmax(x, axis=-1):
    """Numerically stable softmax along *axis* (shift by the max first)."""
    shifted = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return shifted / shifted.sum(axis=axis, keepdims=True)

# Smoke test: run attention on random tensors and inspect the shapes.
batch_size, seq_len, d_k = 2, 5, 8

Q = np.random.randn(batch_size, seq_len, d_k)
K = np.random.randn(batch_size, seq_len, d_k)
V = np.random.randn(batch_size, seq_len, d_k)

output, weights = scaled_dot_product_attention(Q, K, V)

print(f"Q, K, V 形状: {Q.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"\n注意力权重示例 (batch=0):")
print(weights[0])

可视化注意力权重

def visualize_attention_weights(weights, tokens_q=None, tokens_k=None):
    """Render an attention-weight matrix as an annotated heatmap.

    Args:
        weights: (len_q, len_k) array of attention weights.
        tokens_q: optional labels for the query axis (defaults to Q0..Qn).
        tokens_k: optional labels for the key axis (defaults to K0..Kn).
    """
    if tokens_q is None:
        tokens_q = [f'Q{i}' for i in range(weights.shape[0])]
    if tokens_k is None:
        tokens_k = [f'K{i}' for i in range(weights.shape[1])]

    fig, ax = plt.subplots(figsize=(10, 8))

    im = ax.imshow(weights, cmap='Blues')

    ax.set_xticks(np.arange(len(tokens_k)))
    ax.set_yticks(np.arange(len(tokens_q)))
    ax.set_xticklabels(tokens_k)
    ax.set_yticklabels(tokens_q)

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

    # Annotate each cell; use white text on dark (high-weight) cells.
    # (The returned Text objects are not needed, so they are not kept.)
    for i in range(len(tokens_q)):
        for j in range(len(tokens_k)):
            ax.text(j, i, f'{weights[i, j]:.2f}',
                    ha="center", va="center",
                    color="black" if weights[i, j] < 0.5 else "white")

    ax.set_title("注意力权重")
    ax.set_xlabel("Key")
    ax.set_ylabel("Query")

    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()

# Visualize the first batch item, labelling both axes with word tokens.
tokens = ['I', 'love', 'machine', 'learning', '.']
visualize_attention_weights(weights[0], tokens, tokens)

自注意力(Self-Attention)

概念

Q、K、V来自同一个输入序列。

class SelfAttention:
    """Single-head self-attention: Q, K and V are all derived from one input."""

    def __init__(self, d_model, d_k):
        self.d_k = d_k

        # Xavier/Glorot-style scale for the three projection matrices.
        scale = np.sqrt(2.0 / (d_model + d_k))
        self.W_q = np.random.randn(d_model, d_k) * scale
        self.W_k = np.random.randn(d_model, d_k) * scale
        self.W_v = np.random.randn(d_model, d_k) * scale

    def forward(self, X, mask=None):
        """Project X to queries/keys/values and apply attention.

        X: (batch_size, seq_len, d_model)
        Returns (output, attention_weights).
        """
        queries = X @ self.W_q
        keys = X @ self.W_k
        values = X @ self.W_v

        # Delegate the actual attention computation to the shared helper.
        return scaled_dot_product_attention(queries, keys, values, mask)

# Smoke test: run a random batch through the self-attention layer.
d_model, d_k = 64, 32
self_attn = SelfAttention(d_model, d_k)

X = np.random.randn(2, 10, d_model)  # 2 samples, 10 tokens, 64 dims
output, weights = self_attn.forward(X)

print(f"输入形状: {X.shape}")
print(f"输出形状: {output.shape}")

自注意力 vs 全连接

| 特性 | 全连接 | 自注意力 |
| --- | --- | --- |
| 参数量 | $O(n^2 d)$ | $O(d^2)$ |
| 计算复杂度 | $O(n d^2)$ | $O(n^2 d)$ |
| 长程依赖 | 间接 | 直接 |

多头注意力

原理

并行运行多个注意力头,捕获不同的关系模式:

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O\]
class MultiHeadAttention:
    """Multi-head attention.

    Runs n_heads attention heads in parallel so each head can capture a
    different relation pattern, then mixes them with an output projection:
    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_o
    """

    def __init__(self, d_model, n_heads):
        """
        Args:
            d_model: model (embedding) dimension.
            n_heads: number of attention heads; must evenly divide d_model.

        Raises:
            ValueError: if d_model is not divisible by n_heads (otherwise
                the integer division below would silently drop dimensions).
        """
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Combined projection matrices for all heads; the projected
        # tensors are split into per-head slices in split_heads().
        self.W_q = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
        self.W_k = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
        self.W_v = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
        self.W_o = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)."""
        batch_size, seq_len, _ = x.shape
        x = x.reshape(batch_size, seq_len, self.n_heads, self.d_k)
        return x.transpose(0, 2, 1, 3)

    def concat_heads(self, x):
        """Reshape (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
        batch_size, _, seq_len, _ = x.shape
        x = x.transpose(0, 2, 1, 3)
        return x.reshape(batch_size, seq_len, -1)

    def forward(self, X, mask=None):
        """Apply multi-head self-attention to X.

        Args:
            X: (batch, seq_len, d_model) input.
            mask: optional additive mask broadcastable to the score
                tensor; entries equal to 1 are suppressed.

        Returns:
            (output, weights): output is (batch, seq_len, d_model),
            weights is (batch, n_heads, seq_len, seq_len).
        """
        # Project, then split the combined projections into heads.
        Q = self.split_heads(X @ self.W_q)
        K = self.split_heads(X @ self.W_k)
        V = self.split_heads(X @ self.W_v)

        # Scaled dot-product attention, computed per head.
        d_k = Q.shape[-1]
        scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(d_k)

        if mask is not None:
            scores = scores + (mask * -1e9)

        weights = softmax(scores, axis=-1)
        context = np.matmul(weights, V)

        # Merge heads and apply the output projection.
        output = self.concat_heads(context) @ self.W_o

        return output, weights

# Smoke test: run a random batch through the multi-head layer.
d_model, n_heads = 64, 8
mha = MultiHeadAttention(d_model, n_heads)

X = np.random.randn(2, 10, d_model)
output, weights = mha.forward(X)

print(f"输入形状: {X.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")  # (batch, n_heads, seq, seq)

可视化多头注意力

def visualize_multihead_attention(weights, head_idx=None):
    """Visualize multi-head attention weights for the first batch item.

    Args:
        weights: (batch, n_heads, seq_len, seq_len) attention weights.
        head_idx: if given, show only that head; otherwise show up to 8
            heads in a 2x4 grid.
    """
    n_heads = weights.shape[1]

    if head_idx is None:
        # Show all heads (capped at 8 to fit the 2x4 grid).
        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        axes = axes.flatten()

        for i in range(min(n_heads, 8)):
            ax = axes[i]
            im = ax.imshow(weights[0, i], cmap='Blues')
            ax.set_title(f'Head {i+1}')
            ax.set_xlabel('Key')
            ax.set_ylabel('Query')

        # Hide leftover axes when there are fewer than 8 heads.
        for i in range(min(n_heads, 8), 8):
            axes[i].axis('off')

        plt.suptitle('多头注意力权重', fontsize=14)
        plt.tight_layout()
        plt.show()
    else:
        # Show a single head with the shared heatmap helper.
        visualize_attention_weights(weights[0, head_idx])

visualize_multihead_attention(weights)

位置编码

为什么需要位置编码

自注意力是置换不变的,需要额外的位置信息。

正弦位置编码

\(PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})\) \(PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})\)

def positional_encoding(max_len, d_model):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    PE[pos, 2i]   = sin(pos / 10000^(2i/d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))

    Args:
        max_len: number of positions to encode.
        d_model: encoding dimension. Odd values are supported (the
            original version raised a broadcast error because the cosine
            slice has one fewer column than the sine slice).

    Returns:
        (max_len, d_model) array of encodings.
    """
    pe = np.zeros((max_len, d_model))
    position = np.arange(max_len)[:, np.newaxis]

    # 10000^(-2i/d_model), computed in log space for numerical stability.
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))

    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe[:, 0::2] = np.sin(angles)
    # For odd d_model there is one fewer cosine column than sine column.
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])

    return pe

# Generate the sinusoidal positional encodings.
max_len, d_model = 100, 64
pe = positional_encoding(max_len, d_model)

# Visualize them in two panels.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: heatmap of the first 50 positions x 32 dimensions.
ax = axes[0]
im = ax.imshow(pe[:50, :32], cmap='RdBu', aspect='auto')
ax.set_xlabel('维度')
ax.set_ylabel('位置')
ax.set_title('位置编码热力图')
plt.colorbar(im, ax=ax)

# Right panel: waveforms of a few individual dimensions.
ax = axes[1]
for dim in [0, 4, 8, 16]:
    ax.plot(pe[:50, dim], label=f'dim={dim}')
ax.set_xlabel('位置')
ax.set_ylabel('编码值')
ax.set_title('不同维度的位置编码')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"位置编码形状: {pe.shape}")

可学习位置编码

class LearnedPositionalEncoding:
    """Positional encoding stored as a (would-be trainable) lookup table."""

    def __init__(self, max_len, d_model):
        # Small random initialization, as is typical for learned embeddings.
        self.pe = 0.02 * np.random.randn(max_len, d_model)

    def forward(self, seq_len):
        """Return the first *seq_len* rows of the encoding table."""
        return self.pe[:seq_len]

# Usage example: a 512-position table of 64-dim learned encodings.
learned_pe = LearnedPositionalEncoding(max_len=512, d_model=64)
pe_slice = learned_pe.forward(seq_len=10)
print(f"可学习位置编码形状: {pe_slice.shape}")

Mask机制

Padding Mask

def create_padding_mask(seq, pad_token=0):
    """Build a padding mask: 1.0 where *seq* equals *pad_token*, else 0.0.

    The result is shaped (batch, 1, 1, seq_len) so it broadcasts over the
    head and query dimensions of an attention score tensor.
    """
    is_pad = (seq == pad_token).astype(float)
    return is_pad[:, np.newaxis, np.newaxis, :]

# Example: 0 is the padding token.
seq = np.array([[1, 2, 3, 0, 0],
                [1, 2, 3, 4, 0]])  # trailing zeros are padding

padding_mask = create_padding_mask(seq)
print("Padding mask:")
print(padding_mask.squeeze())

Causal Mask(因果掩码)

def create_causal_mask(seq_len):
    """Causal mask for autoregressive decoding.

    Returns a (seq_len, seq_len) matrix with 1s strictly above the
    diagonal (future positions to hide) and 0s elsewhere.
    """
    return np.triu(np.ones((seq_len, seq_len)), k=1)

# Example: each position may only attend to itself and earlier positions.
causal_mask = create_causal_mask(5)
print("Causal mask:")
print(causal_mask)

# Visualize both mask types (plot 1 - mask so green marks visible cells).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

ax = axes[0]
ax.imshow(1 - create_causal_mask(8), cmap='Greens')
ax.set_title('Causal Mask (绿色=可见)')
ax.set_xlabel('Key位置')
ax.set_ylabel('Query位置')

ax = axes[1]
ax.imshow(1 - create_padding_mask(seq)[0, 0], cmap='Greens')
ax.set_title('Padding Mask (绿色=可见)')
ax.set_xlabel('Key位置')
ax.set_ylabel('Query位置')

plt.tight_layout()
plt.show()

PyTorch实现

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiHeadAttentionPyTorch(nn.Module):
        """Multi-head attention built from PyTorch modules.

        NOTE: here ``mask == 0`` marks positions to hide, which is the
        opposite convention from the additive NumPy masks above.
        """

        def __init__(self, d_model, n_heads, dropout=0.1):
            super().__init__()
            self.d_model = d_model
            self.n_heads = n_heads
            self.d_k = d_model // n_heads

            self.W_q = nn.Linear(d_model, d_model)
            self.W_k = nn.Linear(d_model, d_model)
            self.W_v = nn.Linear(d_model, d_model)
            self.W_o = nn.Linear(d_model, d_model)

            self.dropout = nn.Dropout(dropout)

        def forward(self, Q, K, V, mask=None):
            """Compute attention for Q/K/V of shape (batch, seq, d_model)."""
            batch_size = Q.size(0)

            def project(layer, x):
                # (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
                return layer(x).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

            q = project(self.W_q, Q)
            k = project(self.W_k, K)
            v = project(self.W_v, V)

            # Scaled dot-product scores per head.
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)

            if mask is not None:
                scores = scores.masked_fill(mask == 0, -1e9)

            attn_weights = self.dropout(F.softmax(scores, dim=-1))
            context = torch.matmul(attn_weights, v)

            # (batch, n_heads, seq, d_k) -> (batch, seq, d_model), then project.
            context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
            return self.W_o(context), attn_weights

    # Smoke test of the hand-rolled module.
    mha = MultiHeadAttentionPyTorch(d_model=512, n_heads=8)
    x = torch.randn(2, 10, 512)
    output, weights = mha(x, x, x)

    print("PyTorch多头注意力:")
    print(f"  输入: {x.shape}")
    print(f"  输出: {output.shape}")
    print(f"  注意力权重: {weights.shape}")

    # Compare with the built-in implementation.
    mha_builtin = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
    output_builtin, weights_builtin = mha_builtin(x, x, x)
    print(f"\n内置MultiheadAttention输出: {output_builtin.shape}")

except ImportError:
    print("PyTorch未安装")

常见问题

Q1: 为什么要除以$\sqrt{d_k}$?

防止点积过大导致softmax梯度过小。

Q2: 多头注意力有什么好处?

  • 捕获不同类型的关系
  • 增加表达能力
  • 类似CNN的多通道

Q3: 位置编码为什么用正弦函数?

  • 可以表示相对位置
  • 可以外推到更长序列
  • 不需要学习

Q4: 自注意力的计算复杂度?

$O(n^2 d)$,对于长序列是瓶颈,有多种优化方法(Sparse、Linear Attention等)。


总结

| 概念 | 描述 |
| --- | --- |
| 注意力 | 动态加权聚合信息 |
| 自注意力 | Q、K、V来自同一输入 |
| 多头注意力 | 并行多个注意力头 |
| 位置编码 | 注入序列位置信息 |

参考资料

  • Vaswani, A. et al. (2017). “Attention Is All You Need”
  • The Illustrated Transformer (Jay Alammar)
  • The Annotated Transformer (Harvard NLP)
  • CS224N: Attention Mechanisms

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:《 机器学习基础系列——注意力机制 》

本文链接:http://localhost:3015/ai/%E6%B3%A8%E6%84%8F%E5%8A%9B%E6%9C%BA%E5%88%B6.html

本文最后一次更新时间未能显示,文章中的某些内容可能已过时!