蒸馏不是把模型简单缩小,而是把大模型的判断结构有选择地迁移进更便宜的学生网络
蒸馏解决的不是“训练不了”,而是“上线太贵”
前面几篇一直在围绕两个方向打转。
要么把模型权重压小。
要么把训练方式变轻。
知识蒸馏走的是第三条路。
它不执着于保留原模型全部参数。
而是问一个更直接的问题。
如果大模型已经学会了某种行为模式,能不能把这种模式转移给一个更小的学生模型。
这样做的目的非常明确。
不是为了学术上的优雅。
而是为了让推理成本、延迟和部署门槛一起下降。
教师-学生范式到底在迁移什么
教师-学生范式teacher-student paradigmA teacher-student setup trains a smaller student model to mimic the outputs or internal behavior of a larger teacher model, transferring capability without copying the full parameter count.蒸馏并不是让学生去背老师的参数。
学生也不可能照抄老师的结构。
它迁移的是输出分布、类别关系、排序偏好,或者更深一层的中间表示。
这件事为什么有效。
因为老师给出的不仅是“标准答案”。
还有一整套更细的概率结构。
例如在分类任务里,真实标签只会告诉你哪一类是对的。
但老师模型的 logits 会告诉你:
第一名之外,哪些类别也有一定相似性。
哪些类别虽然错,但错得更接近。
这些信息对小模型来说,往往比一条硬标签更有教学价值。
软标签为什么比硬标签更有信息
软标签soft targetsSoft targets are probability distributions produced by the teacher model. They reveal relative confidence across classes or tokens instead of only the single correct label.传统监督训练里,标签往往是 one-hot。
对就是 1。
错就是 0。
这种信号很干净,但也很粗。
蒸馏希望学生看到更细腻的分布。
比如老师对某个 token 的判断,不是“只有 A 对”。
而是“ A 最可能,B 次之,C 也不是完全不可能”。
这类分布能告诉学生更多关于输出空间几何结构的信息。
尤其是类别之间相近、语言生成存在多种合理答案时,这种结构信息很重要。
温度 T 不是装饰参数
如果直接对老师原始 softmax 分布做模仿,常常会有一个问题。
概率过于尖锐。
最大项几乎压扁其他项。
这时学生能学到的信息仍然有限。
于是蒸馏里通常引入温度 T。
当 T > 1 时,分布会变平。
次优选项的概率被抬起来。
老师“偏向谁、差多少”这类细节更容易显露出来。
这就是经典蒸馏公式里为什么经常同时出现 softmax、log_softmax 和 T^2 缩放项。
它不是经验主义拼接。
而是在控制信息密度与梯度尺度。
KL 散度为什么常被用作蒸馏损失
KL 散度KL divergenceKL divergence measures how different one probability distribution is from another. In distillation it is commonly used to align the student output distribution with the teacher distribution.蒸馏最常见的目标,是让学生输出分布尽量接近老师。
KL 散度正适合做这件事。
它比较的不是单个标签。
而是整个分布。
因此在最经典的蒸馏训练里,损失函数通常会由两部分组成。
- 对真实标签的交叉熵。
- 对老师分布的 KL 蒸馏损失。
前者保证学生不偏离任务定义。
后者则把老师学到的结构信息灌进去。
这也是为什么蒸馏通常比“只拿老师生成一堆伪标签再重训”更有控制力。
蒸馏和微调、量化、剪枝不是一回事
知识蒸馏很容易和另外几类压缩技术混在一起。
实际上它们处理的对象不同。
- 微调主要是在已有模型上继续适配任务。
- 量化主要是减少参数表示位宽。
- 剪枝主要是删掉冗余结构或连接。
- 蒸馏主要是把大模型行为迁移到更小模型。
这意味着蒸馏不必单独存在。
它经常可以和其他方法叠加。
比如先蒸馏得到更小学生模型,再做量化。
或者先有一个 LoRA 微调过的强教师,再把它蒸馏给更小底座。
只要你清楚每一步压缩的是哪一层成本,这条路线就会很自然。
为什么蒸馏在 LLM 里既有价值,也有难点
在分类任务里,蒸馏相对直接。
输出空间固定。
损失定义清楚。
LLM 场景里却更复杂。
因为输出是一个长序列。
每个位置都是一个大词表分布。
而且老师与学生可能 tokenizer、上下文能力、指令格式都不同。
所以 LLM 蒸馏通常至少要回答几个问题。
- 是蒸馏 logits,还是蒸馏 hidden states。
- 用老师自由生成的答案,还是用标准数据集标签。
- 老师和学生词表不完全一致时怎么处理。
- 长序列里的损失如何对齐与截断。
也正因如此,蒸馏在 LLM 里更像系统工程,而不是一段短脚本。
最小可运行的蒸馏训练循环
下面给一个能直接说明思路的最小版本。
它把蒸馏放在 causal LM 上。
假设老师和学生使用相同 tokenizer。
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
teacher_id = "Qwen/Qwen2.5-3B-Instruct"
student_id = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(student_id)
teacher = AutoModelForCausalLM.from_pretrained(
teacher_id,
torch_dtype="auto",
device_map=device,
).eval()
student = AutoModelForCausalLM.from_pretrained(
student_id,
torch_dtype="auto",
device_map=device,
)
optimizer = torch.optim.AdamW(student.parameters(), lr=2e-5)
T = 2.0
alpha = 0.7
batch = tokenizer(
[
"请解释为什么 KV cache 能加速解码。",
"请比较量化和蒸馏的差别。",
],
return_tensors="pt",
padding=True,
truncation=True,
).to(device)
labels = batch["input_ids"].clone()
with torch.no_grad():
teacher_out = teacher(**batch)
student_out = student(**batch)
teacher_logits = teacher_out.logits[:, :-1, :]
student_logits = student_out.logits[:, :-1, :]
shift_labels = labels[:, 1:]
ce_loss = F.cross_entropy(
student_logits.reshape(-1, student_logits.size(-1)),
shift_labels.reshape(-1),
ignore_index=tokenizer.pad_token_id,
)
teacher_probs = F.softmax(teacher_logits / T, dim=-1)
student_log_probs = F.log_softmax(student_logits / T, dim=-1)
kd_loss = F.kl_div(
student_log_probs,
teacher_probs,
reduction="batchmean",
) * (T ** 2)
loss = alpha * kd_loss + (1 - alpha) * ce_loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
print("ce_loss =", float(ce_loss))
print("kd_loss =", float(kd_loss))
print("total_loss =", float(loss))
这段代码省略了数据加载、epoch 和验证集。
但关键部件都在了。
老师只前向,不更新。
学生同时接收真实标签和老师分布的监督。
如果用 PyTorch 的 KLDivLoss 写法会更显式
有些团队更喜欢把 KL 损失写得更清楚。
下面是等价但更显式的写法。
import torch
import torch.nn as nn
import torch.nn.functional as F
criterion_kd = nn.KLDivLoss(reduction="batchmean")
teacher_probs = F.softmax(teacher_logits / T, dim=-1)
student_log_probs = F.log_softmax(student_logits / T, dim=-1)
kd_loss = criterion_kd(student_log_probs, teacher_probs) * (T ** 2)
PyTorch 官方教程也采用了类似的温度缩放思路。
如果你需要更稳定地扩展到完整训练循环,建议直接参考官方蒸馏教程的结构。
温度和权重怎么调
蒸馏最常见的超参,通常有两个。
一个是温度 T。
一个是蒸馏损失与真实标签损失的权重。
经验上:
T太低,老师分布太尖,软标签价值不明显。T太高,分布过平,区分度下降。- 蒸馏权重太高,学生可能过度模仿老师而忽略真实标签。
- 蒸馏权重太低,蒸馏又退化成普通监督训练。
因此这不是可以永远照抄的固定参数。
要结合任务、模型差距和数据噪声来调。
蒸馏最容易失败的地方
蒸馏失败并不总是因为方法无效。
很多时候是设置本身不合理。
常见问题包括:
- 教师和学生能力差距过大,学生容量接不住。
- 数据集太小,学生只学到老师的局部习惯。
- tokenizer 不兼容,导致对齐复杂且损失混乱。
- 蒸馏目标只看 logits,却忽略了任务格式和解码行为。
- 评估只看训练损失,没有做真实推理质量验证。
所以蒸馏不应被理解为一种自动压缩按钮。
它仍然需要认真设计数据与评测。
蒸馏适合什么时候做
和量化相比,蒸馏压的是另一层成本
量化通常保留原始模型结构。
它压的是权重表示位宽。
蒸馏则可能连架构和参数规模都一起改掉。
所以二者最大的区别是:
量化更像“把同一辆车做轻量化”。
蒸馏更像“让一辆更小的车学会接近原车的驾驶方式”。
这也决定了蒸馏的上线收益有时会更大。
因为学生模型可以天然更小、更快、更省 KV Cache。
但代价是训练路径更长,效果更依赖数据与教师质量。
和剪枝相比,蒸馏更少碰结构伤口
剪枝直接删结构。
它要求系统承受“删完之后还能不能高效执行”的问题。
如果底层内核、图编译或稀疏支持不成熟,理论参数下降不一定转化为真实收益。
蒸馏则更绕开这类结构执行问题。
你直接训练一个小模型。
部署路径更清楚。
因此在很多工业实践里,蒸馏比激进剪枝更容易转化为真实上线收益。
本篇要点
- 知识蒸馏通过教师-学生范式,把大模型输出分布中的结构信息迁移给小模型。
- 软标签和温度
T的作用,是让学生看到比硬标签更丰富的相对概率关系。 - KL 散度常用来对齐学生与教师分布,再和真实标签交叉熵组合成总损失。
- 蒸馏与微调、量化、剪枝关注的压缩层次不同,常常可以叠加使用。
- LLM 蒸馏更像系统工程,需要同时处理 tokenizer、序列对齐、损失设计和真实业务评测。
下一篇
上一篇讲的是怎样把训练成本压低,这一篇讲的是怎样把大模型能力迁移给小模型。下一篇会回到部署执行栈本身,横向比较 vLLM、TGI、TensorRT-LLM 和 SGLang,看看不同引擎各自把瓶颈压在了哪里。
参考资料
版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。
(采用 CC BY-NC-SA 4.0 许可协议进行授权)
本文标题:知识蒸馏
本文链接:https://www.sshipanoo.com/blog/ai/inference-opt/05-知识蒸馏/
本文最后一次更新为 天前,文章中的某些内容可能已过时!