模型压缩与知识迁移技术

前言

知识蒸馏(KD)将大型教师模型的知识迁移到小型学生模型,实现模型压缩同时保持性能。本文介绍知识蒸馏的原理与实现。


知识蒸馏概述

核心思想

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

print("知识蒸馏核心概念:")
print("=" * 50)
print("• 教师模型 (Teacher): 大型、高性能模型")
print("• 学生模型 (Student): 小型、高效模型")
print("• 软标签 (Soft Labels): 教师的概率分布输出")
print("• 温度 (Temperature): 控制软标签的平滑程度")
print()
print("为什么软标签有效:")
print("• 包含类别间的相似性信息")
print("• 提供更丰富的监督信号")
print("• 比硬标签(one-hot)信息量更大")

软标签示例

def softmax_with_temperature(logits, temperature=1.0):
    """带温度的softmax"""
    scaled_logits = logits / temperature
    exp_logits = np.exp(scaled_logits - np.max(scaled_logits))
    return exp_logits / np.sum(exp_logits)

# 教师模型的logits
teacher_logits = np.array([5.0, 2.0, 1.0, 0.5])
classes = ['', '', '', '']

print("不同温度下的概率分布:")
print("-" * 50)

temperatures = [1.0, 2.0, 5.0, 10.0]
for T in temperatures:
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"\nT = {T}:")
    for cls, prob in zip(classes, probs):
        bar = '' * int(prob * 30)
        print(f"  {cls}: {prob:.4f} {bar}")

Hinton蒸馏

经典知识蒸馏

蒸馏损失:

\[L = \alpha \cdot L_{CE}(y, p_s) + (1-\alpha) \cdot T^2 \cdot L_{KL}(p_t^{(T)}, p_s^{(T)})\]
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class KnowledgeDistillationLoss(nn.Module):
        """知识蒸馏损失"""
        
        def __init__(self, temperature=4.0, alpha=0.5):
            super().__init__()
            self.T = temperature
            self.alpha = alpha
        
        def forward(self, student_logits, teacher_logits, labels):
            """
            student_logits: 学生模型输出 [batch, num_classes]
            teacher_logits: 教师模型输出 [batch, num_classes]
            labels: 真实标签 [batch]
            """
            # 硬标签损失(交叉熵)
            hard_loss = F.cross_entropy(student_logits, labels)
            
            # 软标签损失(KL散度)
            soft_student = F.log_softmax(student_logits / self.T, dim=1)
            soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
            soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
            
            # T^2是因为梯度会被T缩放
            loss = self.alpha * hard_loss + (1 - self.alpha) * (self.T ** 2) * soft_loss
            
            return loss
    
    # 测试
    batch_size = 4
    num_classes = 10
    
    student_logits = torch.randn(batch_size, num_classes)
    teacher_logits = torch.randn(batch_size, num_classes)
    labels = torch.randint(0, num_classes, (batch_size,))
    
    kd_loss = KnowledgeDistillationLoss(temperature=4.0, alpha=0.5)
    loss = kd_loss(student_logits, teacher_logits, labels)
    print(f"蒸馏损失: {loss.item():.4f}")
    
except ImportError:
    print("PyTorch未安装")

NumPy实现

def kl_divergence(p, q, eps=1e-10):
    """KL散度"""
    return np.sum(p * np.log((p + eps) / (q + eps)))

def knowledge_distillation_loss_numpy(student_logits, teacher_logits, 
                                       labels, temperature=4.0, alpha=0.5):
    """NumPy实现知识蒸馏损失"""
    
    batch_size = student_logits.shape[0]
    
    # 硬标签损失
    student_probs = softmax_with_temperature(student_logits, 1.0)
    hard_loss = -np.mean(np.log(student_probs[np.arange(batch_size), labels] + 1e-10))
    
    # 软标签损失
    soft_losses = []
    for i in range(batch_size):
        soft_student = softmax_with_temperature(student_logits[i], temperature)
        soft_teacher = softmax_with_temperature(teacher_logits[i], temperature)
        soft_losses.append(kl_divergence(soft_teacher, soft_student))
    
    soft_loss = np.mean(soft_losses)
    
    # 总损失
    loss = alpha * hard_loss + (1 - alpha) * (temperature ** 2) * soft_loss
    
    return loss, hard_loss, soft_loss

# 测试
batch_size = 4
num_classes = 5

student_logits = np.random.randn(batch_size, num_classes)
teacher_logits = np.random.randn(batch_size, num_classes)
labels = np.random.randint(0, num_classes, batch_size)

loss, hard, soft = knowledge_distillation_loss_numpy(
    student_logits, teacher_logits, labels
)
print(f"总损失: {loss:.4f}")
print(f"硬标签损失: {hard:.4f}")
print(f"软标签损失: {soft:.4f}")

特征蒸馏

FitNets

try:
    class FeatureDistillationLoss(nn.Module):
        """特征蒸馏损失"""
        
        def __init__(self, student_dim, teacher_dim):
            super().__init__()
            # 如果维度不匹配,添加适配层
            self.adapter = None
            if student_dim != teacher_dim:
                self.adapter = nn.Linear(student_dim, teacher_dim)
        
        def forward(self, student_features, teacher_features):
            """
            蒸馏中间层特征
            """
            if self.adapter is not None:
                student_features = self.adapter(student_features)
            
            # MSE损失
            loss = F.mse_loss(student_features, teacher_features)
            
            return loss
    
    
    class AttentionTransfer(nn.Module):
        """注意力迁移"""
        
        def __init__(self):
            super().__init__()
        
        def attention_map(self, features):
            """计算注意力图"""
            # features: [B, C, H, W]
            # 对通道维度求和的平方
            return torch.pow(features, 2).sum(dim=1)
        
        def forward(self, student_features, teacher_features):
            """
            迁移注意力
            """
            # 计算注意力图
            student_att = self.attention_map(student_features)
            teacher_att = self.attention_map(teacher_features)
            
            # 归一化
            student_att = student_att / student_att.sum(dim=[1, 2], keepdim=True)
            teacher_att = teacher_att / teacher_att.sum(dim=[1, 2], keepdim=True)
            
            # L2损失
            loss = (student_att - teacher_att).pow(2).mean()
            
            return loss
    
    print("特征蒸馏类型:")
    print("  • FitNets: 中间层特征匹配")
    print("  • 注意力迁移: 迁移注意力图")
    print("  • 关系蒸馏: 迁移样本间关系")
    
except NameError:
    print("需要先导入PyTorch")

自蒸馏

模型自身蒸馏

try:
    class SelfDistillation(nn.Module):
        """自蒸馏模型"""
        
        def __init__(self, base_model, num_classes, num_branches=3):
            super().__init__()
            
            self.base_model = base_model
            self.num_branches = num_branches
            
            # 多个辅助分类器(在不同深度)
            self.classifiers = nn.ModuleList([
                nn.Linear(256, num_classes) for _ in range(num_branches)
            ])
        
        def forward(self, x, return_features=False):
            # 获取不同层的特征
            features = self.base_model.get_intermediate_features(x)
            
            # 每个分支的输出
            outputs = [clf(feat) for clf, feat in zip(self.classifiers, features)]
            
            if return_features:
                return outputs, features
            return outputs
        
        def distillation_loss(self, outputs, labels, temperature=3.0):
            """
            深层输出指导浅层
            """
            # 最深层作为教师
            teacher_logits = outputs[-1].detach()
            
            total_loss = 0
            for i, student_logits in enumerate(outputs[:-1]):
                # CE损失
                ce_loss = F.cross_entropy(student_logits, labels)
                
                # KD损失
                soft_student = F.log_softmax(student_logits / temperature, dim=1)
                soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
                kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
                
                total_loss += ce_loss + (temperature ** 2) * kd_loss
            
            # 最深层的CE损失
            total_loss += F.cross_entropy(outputs[-1], labels)
            
            return total_loss
    
    print("自蒸馏优势:")
    print("  • 不需要额外的教师模型")
    print("  • 训练过程中自我改进")
    print("  • 可以加速收敛")
    
except NameError:
    print("需要先导入PyTorch")

完整训练流程

try:
    def train_with_distillation(teacher, student, train_loader, 
                                 epochs=10, temperature=4.0, alpha=0.5):
        """知识蒸馏训练"""
        
        teacher.eval()  # 教师模型固定
        student.train()
        
        optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
        kd_criterion = KnowledgeDistillationLoss(temperature, alpha)
        
        for epoch in range(epochs):
            total_loss = 0
            
            for batch_idx, (data, labels) in enumerate(train_loader):
                optimizer.zero_grad()
                
                # 教师输出(无梯度)
                with torch.no_grad():
                    teacher_logits = teacher(data)
                
                # 学生输出
                student_logits = student(data)
                
                # 蒸馏损失
                loss = kd_criterion(student_logits, teacher_logits, labels)
                
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        
        return student
    
    print("训练流程:")
    print("  1. 加载预训练的教师模型")
    print("  2. 初始化学生模型")
    print("  3. 冻结教师模型参数")
    print("  4. 使用蒸馏损失训练学生")
    
except NameError:
    print("需要先导入PyTorch")

蒸馏技术变体

多种蒸馏方法

方法 知识类型 描述
Hinton KD 输出 软标签蒸馏
FitNets 特征 中间层特征匹配
AT 注意力 注意力图迁移
RKD 关系 样本间关系迁移
CRD 对比 对比表示蒸馏
DKD 解耦 目标类和非目标类分开
# 关系蒸馏损失
def relation_distillation_loss(student_features, teacher_features):
    """关系蒸馏:迁移样本间关系"""
    
    # 计算学生样本间距离
    student_dist = np.sqrt(np.sum(
        (student_features[:, None] - student_features[None, :]) ** 2, axis=-1
    ))
    
    # 计算教师样本间距离
    teacher_dist = np.sqrt(np.sum(
        (teacher_features[:, None] - teacher_features[None, :]) ** 2, axis=-1
    ))
    
    # 归一化
    student_dist = student_dist / (student_dist.mean() + 1e-8)
    teacher_dist = teacher_dist / (teacher_dist.mean() + 1e-8)
    
    # Huber损失
    loss = np.mean(np.abs(student_dist - teacher_dist))
    
    return loss

# 测试
batch_size = 8
dim = 64
student_feat = np.random.randn(batch_size, dim)
teacher_feat = np.random.randn(batch_size, dim)

rkd_loss = relation_distillation_loss(student_feat, teacher_feat)
print(f"关系蒸馏损失: {rkd_loss:.4f}")

常见问题

Q1: 温度参数如何选择?

  • 通常T=3-20
  • 较高温度→更平滑的分布→更多暗知识
  • 通过验证集调优

Q2: α参数如何设置?

  • α=0.5是常见选择
  • 数据量少时增大软标签权重
  • 教师模型很强时减小硬标签权重

Q3: 学生模型多小合适?

  • 取决于部署需求
  • 通常保持10-50%的参数量
  • 结构相似性有助于蒸馏

Q4: 蒸馏失败的原因?

  • 容量差距过大
  • 温度设置不当
  • 特征维度不匹配
  • 训练数据不足

总结

概念 描述
软标签 教师模型的概率分布
温度 控制分布平滑程度
暗知识 类别间的隐含关系
特征蒸馏 迁移中间层表示

参考资料

  • Hinton, G. et al. (2015). “Distilling the Knowledge in a Neural Network”
  • Romero, A. et al. (2015). “FitNets: Hints for Thin Deep Nets”
  • Zagoruyko, S. & Komodakis, N. (2017). “Paying More Attention to Attention”
  • Park, W. et al. (2019). “Relational Knowledge Distillation”

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:《 机器学习基础系列——知识蒸馏 》

本文链接:http://localhost:3015/ai/%E7%9F%A5%E8%AF%86%E8%92%B8%E9%A6%8F.html

本文最后一次更新为 天前,文章中的某些内容可能已过时!