对比学习与掩码自编码器详解

前言

自监督学习(SSL)通过设计pretext任务从无标注数据中学习表示。本文介绍对比学习、掩码自编码器等主流方法。


自监督学习概述

核心思想

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

print("自监督学习核心概念:")
print("=" * 50)
print("• 无需人工标注")
print("• 从数据本身构造监督信号")
print("• 学习通用表示")
print()
print("主要范式:")
print("• 对比学习: 拉近相似样本,推远不同样本")
print("• 生成式: 重建输入(如MAE)")
print("• 预测式: 预测数据的某些属性")

Pretext任务示例

# 常见的pretext任务
pretext_tasks = {
    '图像': [
        '旋转预测: 预测图像旋转角度(0°, 90°, 180°, 270°)',
        '拼图: 打乱图像块后还原顺序',
        '着色: 从灰度图预测颜色',
        '对比学习: 区分同一图像的不同增强',
        'MAE: 遮挡部分patches后重建'
    ],
    '文本': [
        'MLM: 掩码语言模型(BERT)',
        'NSP: 下一句预测',
        'GPT: 自回归生成',
    ],
    '序列': [
        '预测下一步',
        '重建被掩码的序列'
    ]
}

for domain, tasks in pretext_tasks.items():
    print(f"\n{domain}领域:")
    for task in tasks:
        print(f"{task}")

对比学习

SimCLR框架

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class SimCLR(nn.Module):
        """简化的SimCLR实现"""
        
        def __init__(self, encoder, projection_dim=128):
            super().__init__()
            self.encoder = encoder
            
            # 获取encoder输出维度
            with torch.no_grad():
                dummy = torch.randn(1, 3, 32, 32)
                enc_dim = encoder(dummy).shape[1]
            
            # 投影头
            self.projector = nn.Sequential(
                nn.Linear(enc_dim, 256),
                nn.ReLU(),
                nn.Linear(256, projection_dim)
            )
        
        def forward(self, x):
            # 编码
            h = self.encoder(x)
            # 投影
            z = self.projector(h)
            # L2归一化
            z = F.normalize(z, dim=1)
            return z
    
    
    class NTXentLoss(nn.Module):
        """NT-Xent对比损失"""
        
        def __init__(self, temperature=0.5):
            super().__init__()
            self.temperature = temperature
        
        def forward(self, z_i, z_j):
            """
            z_i, z_j: [batch_size, projection_dim]
            z_i和z_j是同一图像的两个不同增强
            """
            batch_size = z_i.shape[0]
            
            # 拼接所有表示
            z = torch.cat([z_i, z_j], dim=0)  # [2*batch_size, dim]
            
            # 计算相似度矩阵
            sim = torch.mm(z, z.T) / self.temperature  # [2N, 2N]
            
            # 创建标签:正样本对的索引
            labels = torch.arange(batch_size).to(z.device)
            labels = torch.cat([labels + batch_size, labels])  # [2N]
            
            # 掩码对角线(自己和自己)
            mask = torch.eye(2 * batch_size, dtype=torch.bool).to(z.device)
            sim = sim.masked_fill(mask, float('-inf'))
            
            # 交叉熵损失
            loss = F.cross_entropy(sim, labels)
            
            return loss
    
    
    # 数据增强
    class SimCLRAugmentation:
        """SimCLR数据增强(伪代码)"""
        
        def __init__(self):
            # 实际使用时需要torchvision.transforms
            self.augmentations = [
                '随机裁剪并resize',
                '随机水平翻转',
                '颜色抖动(亮度、对比度、饱和度、色调)',
                '随机灰度化',
                '高斯模糊'
            ]
        
        def __call__(self, x):
            # 返回同一图像的两个增强版本
            return x, x  # 实际需要应用随机增强
    
    print("SimCLR关键组件:")
    print("  • 编码器: ResNet等backbone")
    print("  • 投影头: MLP将表示映射到对比空间")
    print("  • NT-Xent损失: 对比学习目标")
    print("  • 强数据增强: 创造正样本对")
    
except ImportError:
    print("PyTorch未安装")

从零实现对比学习

def numpy_contrastive_loss(z_i, z_j, temperature=0.5):
    """NumPy实现对比损失"""
    
    batch_size = z_i.shape[0]
    
    # L2归一化
    z_i = z_i / np.linalg.norm(z_i, axis=1, keepdims=True)
    z_j = z_j / np.linalg.norm(z_j, axis=1, keepdims=True)
    
    # 拼接
    z = np.vstack([z_i, z_j])  # [2N, dim]
    
    # 相似度矩阵
    sim = np.dot(z, z.T) / temperature  # [2N, 2N]
    
    # 对角线设为极小值
    np.fill_diagonal(sim, -1e9)
    
    # 正样本对的位置
    # 对于第i个样本,正样本是第i+N个
    # 对于第i+N个样本,正样本是第i个
    
    total_loss = 0
    for i in range(batch_size):
        # 第i个样本的正样本是第i+batch_size个
        pos_idx = i + batch_size
        numerator = np.exp(sim[i, pos_idx])
        denominator = np.sum(np.exp(sim[i, :]))
        loss_i = -np.log(numerator / denominator)
        
        # 第i+batch_size个样本的正样本是第i个
        pos_idx = i
        numerator = np.exp(sim[i + batch_size, pos_idx])
        denominator = np.sum(np.exp(sim[i + batch_size, :]))
        loss_j = -np.log(numerator / denominator)
        
        total_loss += loss_i + loss_j
    
    return total_loss / (2 * batch_size)

# 测试
batch_size = 4
dim = 64
z_i = np.random.randn(batch_size, dim)
z_j = z_i + 0.1 * np.random.randn(batch_size, dim)  # 相似的表示

loss = numpy_contrastive_loss(z_i, z_j)
print(f"对比损失: {loss:.4f}")

MoCo

动量对比学习

try:
    class MoCo(nn.Module):
        """MoCo v2实现"""
        
        def __init__(self, encoder, dim=128, K=4096, m=0.999, T=0.07):
            super().__init__()
            
            self.K = K  # 队列大小
            self.m = m  # 动量系数
            self.T = T  # 温度
            
            # Query编码器
            self.encoder_q = encoder
            # Key编码器(动量更新)
            self.encoder_k = encoder.__class__()
            self.encoder_k.load_state_dict(self.encoder_q.state_dict())
            
            # 冻结key编码器
            for param in self.encoder_k.parameters():
                param.requires_grad = False
            
            # 负样本队列
            self.register_buffer("queue", torch.randn(dim, K))
            self.queue = F.normalize(self.queue, dim=0)
            self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        
        @torch.no_grad()
        def _momentum_update_key_encoder(self):
            """动量更新key编码器"""
            for param_q, param_k in zip(self.encoder_q.parameters(),
                                        self.encoder_k.parameters()):
                param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        
        @torch.no_grad()
        def _dequeue_and_enqueue(self, keys):
            """更新负样本队列"""
            batch_size = keys.shape[0]
            ptr = int(self.queue_ptr)
            
            # 入队
            self.queue[:, ptr:ptr + batch_size] = keys.T
            ptr = (ptr + batch_size) % self.K
            self.queue_ptr[0] = ptr
        
        def forward(self, x_q, x_k):
            # Query编码
            q = self.encoder_q(x_q)  # [N, C]
            q = F.normalize(q, dim=1)
            
            # Key编码(无梯度)
            with torch.no_grad():
                self._momentum_update_key_encoder()
                k = self.encoder_k(x_k)
                k = F.normalize(k, dim=1)
            
            # 正样本logits
            l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # [N, 1]
            
            # 负样本logits
            l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])  # [N, K]
            
            # 拼接
            logits = torch.cat([l_pos, l_neg], dim=1)  # [N, 1+K]
            logits /= self.T
            
            # 标签:正样本在第0位
            labels = torch.zeros(logits.shape[0], dtype=torch.long).to(q.device)
            
            # 更新队列
            self._dequeue_and_enqueue(k)
            
            return logits, labels
    
    print("MoCo关键创新:")
    print("  • 动量编码器: 提供一致的负样本表示")
    print("  • 负样本队列: 支持大batch等效效果")
    print("  • m = 0.999: 缓慢更新key编码器")
    
except NameError:
    print("需要先导入PyTorch")

MAE (Masked Autoencoder)

掩码自编码器

try:
    class PatchEmbed(nn.Module):
        """图像patch嵌入"""
        
        def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
            super().__init__()
            self.patch_size = patch_size
            self.n_patches = (img_size // patch_size) ** 2
            
            self.proj = nn.Conv2d(in_channels, embed_dim, 
                                 kernel_size=patch_size, stride=patch_size)
        
        def forward(self, x):
            # [B, C, H, W] -> [B, embed_dim, H/P, W/P] -> [B, N, embed_dim]
            x = self.proj(x)
            x = x.flatten(2).transpose(1, 2)
            return x
    
    
    class SimpleMAE(nn.Module):
        """简化的MAE实现"""
        
        def __init__(self, img_size=224, patch_size=16, in_channels=3,
                     embed_dim=768, mask_ratio=0.75):
            super().__init__()
            
            self.patch_size = patch_size
            self.mask_ratio = mask_ratio
            n_patches = (img_size // patch_size) ** 2
            
            # Patch嵌入
            self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
            
            # 位置编码
            self.pos_embed = nn.Parameter(torch.randn(1, n_patches, embed_dim) * 0.02)
            
            # 简化的编码器(实际使用ViT)
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8, batch_first=True),
                num_layers=4
            )
            
            # 解码器
            self.decoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8, batch_first=True),
                num_layers=2
            )
            
            # 重建头
            self.decoder_pred = nn.Linear(embed_dim, patch_size ** 2 * in_channels)
            
            # 掩码token
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        def random_masking(self, x):
            """随机掩码patches"""
            B, N, D = x.shape
            len_keep = int(N * (1 - self.mask_ratio))
            
            # 随机排序
            noise = torch.rand(B, N, device=x.device)
            ids_shuffle = torch.argsort(noise, dim=1)
            ids_restore = torch.argsort(ids_shuffle, dim=1)
            
            # 保留的patches
            ids_keep = ids_shuffle[:, :len_keep]
            x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
            
            # 生成mask
            mask = torch.ones([B, N], device=x.device)
            mask[:, :len_keep] = 0
            mask = torch.gather(mask, dim=1, index=ids_restore)
            
            return x_masked, mask, ids_restore
        
        def forward(self, x):
            # Patch嵌入
            x = self.patch_embed(x)
            x = x + self.pos_embed
            
            # 随机掩码
            x, mask, ids_restore = self.random_masking(x)
            
            # 编码
            x = self.encoder(x)
            
            # 添加mask tokens
            B, N, D = x.shape
            mask_tokens = self.mask_token.repeat(B, ids_restore.shape[1] - N, 1)
            x_ = torch.cat([x, mask_tokens], dim=1)
            x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, D))
            
            # 解码
            x = self.decoder(x)
            
            # 预测pixels
            x = self.decoder_pred(x)
            
            return x, mask
    
    print("MAE关键思想:")
    print("  • 高掩码率: 75%的patches被掩码")
    print("  • 非对称架构: 轻量级解码器")
    print("  • 像素级重建: 预测原始像素值")
    print("  • 效率高: 只编码可见patches")
    
except NameError:
    print("需要先导入PyTorch")

BERT风格的自监督

掩码语言模型

def create_mlm_data(tokens, vocab_size, mask_prob=0.15, mask_token=103):
    """创建MLM训练数据"""
    
    masked_tokens = tokens.copy()
    labels = np.full_like(tokens, -100)  # -100表示不计算损失
    
    # 随机选择要掩码的位置
    mask_indices = np.random.random(len(tokens)) < mask_prob
    
    for i, should_mask in enumerate(mask_indices):
        if should_mask:
            labels[i] = tokens[i]  # 保存原始token
            
            rand = np.random.random()
            if rand < 0.8:
                # 80%替换为[MASK]
                masked_tokens[i] = mask_token
            elif rand < 0.9:
                # 10%替换为随机token
                masked_tokens[i] = np.random.randint(vocab_size)
            # 10%保持不变
    
    return masked_tokens, labels

# 示例
tokens = np.array([101, 2054, 2003, 1037, 3899, 102])  # [CLS] what is a dog [SEP]
masked, labels = create_mlm_data(tokens, vocab_size=30000)

print("原始tokens:", tokens)
print("掩码后:", masked)
print("标签:", labels)

方法对比

方法 类型 关键思想 适用场景
SimCLR 对比学习 同图增强为正样本 视觉表示
MoCo 对比学习 动量编码器+队列 视觉表示
BYOL 非对比 无需负样本 视觉表示
MAE 生成式 重建掩码patches 视觉表示
BERT 生成式 掩码语言建模 NLP
GPT 生成式 自回归预测 NLP

常见问题

Q1: 对比学习为什么需要负样本?

负样本防止模型学习到trivial解(所有输入映射到同一点)。

Q2: MAE为什么掩码率这么高?

  • 图像冗余高,低掩码率太简单
  • 迫使模型学习高级语义

Q3: 自监督学习的优势?

  • 利用海量无标注数据
  • 学习通用表示
  • 减少对标注的依赖

Q4: 如何评估自监督模型?

  • 线性评估:冻结backbone,只训练分类头
  • 微调评估:全模型微调
  • 下游任务迁移

总结

概念 描述
对比学习 拉近正样本,推远负样本
掩码建模 遮挡部分输入后重建
数据增强 创造正样本对的关键
表示学习 学习有意义的特征表示

参考资料

  • Chen, T. et al. (2020). “A Simple Framework for Contrastive Learning”
  • He, K. et al. (2020). “Momentum Contrast for Unsupervised Visual Representation”
  • He, K. et al. (2022). “Masked Autoencoders Are Scalable Vision Learners”
  • Devlin, J. et al. (2019). “BERT: Pre-training of Deep Bidirectional Transformers”

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:《 机器学习基础系列——自监督学习 》

本文链接:http://localhost:3015/ai/%E8%87%AA%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0.html

本文最后一次更新为 天前,文章中的某些内容可能已过时!