已经是最新一篇文章了!
已经是最后一篇文章了!
模型压缩与知识迁移技术
前言
知识蒸馏(KD)将大型教师模型的知识迁移到小型学生模型,实现模型压缩同时保持性能。本文介绍知识蒸馏的原理与实现。
知识蒸馏概述
核心思想
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

# Print a quick overview of the core knowledge-distillation concepts.
print("知识蒸馏核心概念:")
print("=" * 50)
for concept in (
    "• 教师模型 (Teacher): 大型、高性能模型",
    "• 学生模型 (Student): 小型、高效模型",
    "• 软标签 (Soft Labels): 教师的概率分布输出",
    "• 温度 (Temperature): 控制软标签的平滑程度",
):
    print(concept)
print()
print("为什么软标签有效:")
for reason in (
    "• 包含类别间的相似性信息",
    "• 提供更丰富的监督信号",
    "• 比硬标签(one-hot)信息量更大",
):
    print(reason)
软标签示例
def softmax_with_temperature(logits, temperature=1.0, axis=-1):
    """Temperature-scaled softmax along `axis` (default: last axis).

    Args:
        logits: array of logits, 1-D [num_classes] or 2-D [batch, num_classes].
        temperature: softening factor T; larger T yields a flatter distribution.
        axis: axis holding the class dimension.

    Returns:
        Probabilities with the same shape as `logits`, summing to 1 along `axis`.

    Bug fix: the original normalized over the *entire* array (global max/sum),
    so a 2-D [batch, num_classes] input produced one joint distribution over
    all batch*num_classes entries instead of one distribution per row.
    Normalization is now per-`axis`; behavior for 1-D input is unchanged.
    """
    scaled = np.asarray(logits, dtype=float) / temperature
    # Subtract the per-axis max for numerical stability before exponentiation.
    scaled = scaled - np.max(scaled, axis=axis, keepdims=True)
    exp_scaled = np.exp(scaled)
    return exp_scaled / np.sum(exp_scaled, axis=axis, keepdims=True)
# Teacher logits for a 4-class toy example; show how higher temperatures
# flatten the resulting probability distribution.
teacher_logits = np.array([5.0, 2.0, 1.0, 0.5])
classes = ['猫', '狗', '鸟', '鱼']
print("不同温度下的概率分布:")
print("-" * 50)
for temp in (1.0, 2.0, 5.0, 10.0):
    probs = softmax_with_temperature(teacher_logits, temp)
    print(f"\nT = {temp}:")
    for label, p in zip(classes, probs):
        # Text bar proportional to the probability.
        bar = '█' * int(p * 30)
        print(f" {label}: {p:.4f} {bar}")
Hinton蒸馏
经典知识蒸馏
蒸馏损失:
\[L = \alpha \cdot L_{CE}(y, p_s) + (1-\alpha) \cdot T^2 \cdot L_{KL}(p_t^{(T)}, p_s^{(T)})\]
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
class KnowledgeDistillationLoss(nn.Module):
    """Hinton-style knowledge-distillation loss.

    Combines cross-entropy on the ground-truth labels with a KL-divergence
    term between temperature-softened teacher and student distributions:
        L = alpha * CE(y, p_s) + (1 - alpha) * T^2 * KL(p_t^T || p_s^T)
    """

    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        """
        student_logits: student outputs [batch, num_classes]
        teacher_logits: teacher outputs [batch, num_classes]
        labels: ground-truth labels [batch]
        """
        # Soft-label term: KL between the softened distributions
        # (kl_div expects log-probabilities for the first argument).
        log_p_student = F.log_softmax(student_logits / self.T, dim=1)
        p_teacher = F.softmax(teacher_logits / self.T, dim=1)
        kl_term = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
        # Hard-label term: plain cross-entropy at T=1.
        ce_term = F.cross_entropy(student_logits, labels)
        # The T^2 factor compensates for the 1/T^2 gradient scaling.
        return self.alpha * ce_term + (1 - self.alpha) * (self.T ** 2) * kl_term
# 测试 — smoke test: random logits through the KD loss (runs inside the try above)
batch_size = 4
num_classes = 10
student_logits = torch.randn(batch_size, num_classes)
teacher_logits = torch.randn(batch_size, num_classes)
labels = torch.randint(0, num_classes, (batch_size,))
kd_loss = KnowledgeDistillationLoss(temperature=4.0, alpha=0.5)
loss = kd_loss(student_logits, teacher_logits, labels)
print(f"蒸馏损失: {loss.item():.4f}")
# Fallback branch: pairs with the try: opened before this fragment.
except ImportError:
print("PyTorch未安装")
NumPy实现
def kl_divergence(p, q, eps=1e-10):
    """KL divergence KL(p || q) = sum_i p_i * log(p_i / q_i).

    `eps` guards both distributions against log(0) and division by zero.
    """
    ratio = (p + eps) / (q + eps)
    return np.sum(p * np.log(ratio))
def knowledge_distillation_loss_numpy(student_logits, teacher_logits,
                                      labels, temperature=4.0, alpha=0.5):
    """NumPy knowledge-distillation loss.

    Args:
        student_logits: [batch, num_classes] student outputs.
        teacher_logits: [batch, num_classes] teacher outputs.
        labels: [batch] integer class labels.
        temperature: softening temperature T for the soft-label term.
        alpha: weight of the hard-label (cross-entropy) term.

    Returns:
        (total_loss, hard_loss, soft_loss) as floats.

    Bug fix: the original passed the whole 2-D logits matrix to
    softmax_with_temperature, which normalized over batch*num_classes
    jointly, so the cross-entropy term was computed on a wrong (joint)
    distribution. Here the softmax is applied row by row (1-D calls).
    """
    batch_size = student_logits.shape[0]
    hard_losses = []
    soft_losses = []
    for i in range(batch_size):
        # Hard-label term: per-row softmax at T=1, then negative log-likelihood.
        probs_i = softmax_with_temperature(student_logits[i], 1.0)
        hard_losses.append(-np.log(probs_i[labels[i]] + 1e-10))
        # Soft-label term: KL(teacher^T || student^T) on softened rows.
        soft_student = softmax_with_temperature(student_logits[i], temperature)
        soft_teacher = softmax_with_temperature(teacher_logits[i], temperature)
        soft_losses.append(kl_divergence(soft_teacher, soft_student))
    hard_loss = np.mean(hard_losses)
    soft_loss = np.mean(soft_losses)
    # T^2 compensates for the gradient scaling introduced by the temperature.
    loss = alpha * hard_loss + (1 - alpha) * (temperature ** 2) * soft_loss
    return loss, hard_loss, soft_loss
# 测试 — demo: KD loss on random logits (values depend on the RNG stream seeded above)
batch_size = 4
num_classes = 5
student_logits = np.random.randn(batch_size, num_classes)
teacher_logits = np.random.randn(batch_size, num_classes)
labels = np.random.randint(0, num_classes, batch_size)
loss, hard, soft = knowledge_distillation_loss_numpy(
student_logits, teacher_logits, labels
)
print(f"总损失: {loss:.4f}")
print(f"硬标签损失: {hard:.4f}")
print(f"软标签损失: {soft:.4f}")
特征蒸馏
FitNets
try:
    class FeatureDistillationLoss(nn.Module):
        """FitNets-style feature distillation: match intermediate-layer features."""

        def __init__(self, student_dim, teacher_dim):
            super().__init__()
            # Project student features to the teacher's width when dims differ.
            self.adapter = None
            if student_dim != teacher_dim:
                self.adapter = nn.Linear(student_dim, teacher_dim)

        def forward(self, student_features, teacher_features):
            """MSE between (possibly adapted) student and teacher features."""
            mapped = student_features
            if self.adapter is not None:
                mapped = self.adapter(mapped)
            return F.mse_loss(mapped, teacher_features)

    class AttentionTransfer(nn.Module):
        """Attention transfer: match normalized spatial attention maps."""

        def __init__(self):
            super().__init__()

        def attention_map(self, features):
            """Channel-wise sum of squared activations: [B, C, H, W] -> [B, H, W]."""
            return torch.pow(features, 2).sum(dim=1)

        def forward(self, student_features, teacher_features):
            """L2 distance between normalized attention maps."""
            s_att = self.attention_map(student_features)
            t_att = self.attention_map(teacher_features)
            # Normalize each map so it sums to 1 over the spatial dimensions.
            s_att = s_att / s_att.sum(dim=[1, 2], keepdim=True)
            t_att = t_att / t_att.sum(dim=[1, 2], keepdim=True)
            return (s_att - t_att).pow(2).mean()

    print("特征蒸馏类型:")
    print(" • FitNets: 中间层特征匹配")
    print(" • 注意力迁移: 迁移注意力图")
    print(" • 关系蒸馏: 迁移样本间关系")
except NameError:
    print("需要先导入PyTorch")
自蒸馏
模型自身蒸馏
try:
    class SelfDistillation(nn.Module):
        """Self-distillation: auxiliary classifiers at several depths; the
        deepest branch teaches the shallower ones (no external teacher)."""

        def __init__(self, base_model, num_classes, num_branches=3):
            super().__init__()
            self.base_model = base_model
            self.num_branches = num_branches
            # One auxiliary classifier per tapped depth.
            # NOTE(review): assumes each intermediate feature has width 256 —
            # confirm against base_model.get_intermediate_features().
            self.classifiers = nn.ModuleList(
                nn.Linear(256, num_classes) for _ in range(num_branches)
            )

        def forward(self, x, return_features=False):
            # Features tapped at different depths of the backbone.
            features = self.base_model.get_intermediate_features(x)
            outputs = [head(feat) for head, feat in zip(self.classifiers, features)]
            if return_features:
                return outputs, features
            return outputs

        def distillation_loss(self, outputs, labels, temperature=3.0):
            """Deepest branch (detached) acts as teacher for the shallower ones."""
            teacher_logits = outputs[-1].detach()
            total_loss = 0
            for student_logits in outputs[:-1]:
                # Ground-truth CE term for this shallow branch.
                ce_loss = F.cross_entropy(student_logits, labels)
                # KD term against the (detached) deepest branch.
                log_p_s = F.log_softmax(student_logits / temperature, dim=1)
                p_t = F.softmax(teacher_logits / temperature, dim=1)
                kd_loss = F.kl_div(log_p_s, p_t, reduction='batchmean')
                total_loss += ce_loss + (temperature ** 2) * kd_loss
            # The deepest branch only gets the ground-truth CE term.
            total_loss += F.cross_entropy(outputs[-1], labels)
            return total_loss

    print("自蒸馏优势:")
    print(" • 不需要额外的教师模型")
    print(" • 训练过程中自我改进")
    print(" • 可以加速收敛")
except NameError:
    print("需要先导入PyTorch")
完整训练流程
try:
    def train_with_distillation(teacher, student, train_loader,
                                epochs=10, temperature=4.0, alpha=0.5):
        """Train `student` under a frozen `teacher` with the KD loss.

        teacher/student: nn.Module classifiers producing logits.
        train_loader: iterable of (data, labels) batches.
        Returns the trained student.
        """
        teacher.eval()  # teacher stays fixed (eval mode, no grads below)
        student.train()
        optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
        kd_criterion = KnowledgeDistillationLoss(temperature, alpha)
        for epoch in range(epochs):
            running = 0.0
            for data, labels in train_loader:
                optimizer.zero_grad()
                # Teacher forward pass without building a graph.
                with torch.no_grad():
                    teacher_logits = teacher(data)
                student_logits = student(data)
                loss = kd_criterion(student_logits, teacher_logits, labels)
                loss.backward()
                optimizer.step()
                running += loss.item()
            avg_loss = running / len(train_loader)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        return student

    print("训练流程:")
    print(" 1. 加载预训练的教师模型")
    print(" 2. 初始化学生模型")
    print(" 3. 冻结教师模型参数")
    print(" 4. 使用蒸馏损失训练学生")
except NameError:
    print("需要先导入PyTorch")
蒸馏技术变体
多种蒸馏方法
| 方法 | 知识类型 | 描述 |
|---|---|---|
| Hinton KD | 输出 | 软标签蒸馏 |
| FitNets | 特征 | 中间层特征匹配 |
| AT | 注意力 | 注意力图迁移 |
| RKD | 关系 | 样本间关系迁移 |
| CRD | 对比 | 对比表示蒸馏 |
| DKD | 解耦 | 目标类和非目标类分开 |
# Relational knowledge distillation (RKD) loss
def relation_distillation_loss(student_features, teacher_features):
    """Transfer pairwise sample-to-sample distance structure (RKD).

    Each [N, D] feature set becomes an N x N Euclidean distance matrix,
    normalized by its mean distance (making the loss scale-invariant),
    and the two matrices are compared. NOTE: despite the original
    "Huber" comment, the penalty here is plain mean-absolute (L1).
    """
    def _normalized_pdist(feats):
        # Pairwise Euclidean distances via broadcasting: [N, N].
        diffs = feats[:, None] - feats[None, :]
        dist = np.sqrt(np.sum(diffs ** 2, axis=-1))
        # Divide by the mean distance for scale invariance.
        return dist / (dist.mean() + 1e-8)

    s_rel = _normalized_pdist(student_features)
    t_rel = _normalized_pdist(teacher_features)
    return np.mean(np.abs(s_rel - t_rel))
# 测试 — demo: RKD loss on random features (values depend on the RNG stream seeded above)
batch_size = 8
dim = 64
student_feat = np.random.randn(batch_size, dim)
teacher_feat = np.random.randn(batch_size, dim)
rkd_loss = relation_distillation_loss(student_feat, teacher_feat)
print(f"关系蒸馏损失: {rkd_loss:.4f}")
常见问题
Q1: 温度参数如何选择?
- 通常T=3-20
- 较高温度→更平滑的分布→更多暗知识
- 通过验证集调优
Q2: α参数如何设置?
- α=0.5是常见选择
- 数据量少时增大软标签权重
- 教师模型很强时减小硬标签权重
Q3: 学生模型多小合适?
- 取决于部署需求
- 通常保持10-50%的参数量
- 结构相似性有助于蒸馏
Q4: 蒸馏失败的原因?
- 容量差距过大
- 温度设置不当
- 特征维度不匹配
- 训练数据不足
总结
| 概念 | 描述 |
|---|---|
| 软标签 | 教师模型的概率分布 |
| 温度 | 控制分布平滑程度 |
| 暗知识 | 类别间的隐含关系 |
| 特征蒸馏 | 迁移中间层表示 |
参考资料
- Hinton, G. et al. (2015). “Distilling the Knowledge in a Neural Network”
- Romero, A. et al. (2015). “FitNets: Hints for Thin Deep Nets”
- Zagoruyko, S. & Komodakis, N. (2017). “Paying More Attention to Attention”
- Park, W. et al. (2019). “Relational Knowledge Distillation”
版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。
(采用 CC BY-NC-SA 4.0 许可协议进行授权)
本文标题:《 机器学习基础系列——知识蒸馏 》
本文链接:http://localhost:3015/ai/%E7%9F%A5%E8%AF%86%E8%92%B8%E9%A6%8F.html
本文最后一次更新时间较早,文章中的某些内容可能已过时!