🫗 知识蒸馏

简介: 知识蒸馏是一种模型压缩技术,通过让小模型(学生)模仿大模型(教师)的输出或中间特征,实现性能逼近甚至超越。核心方法包括软标签蒸馏、带温度的Softmax提升信息保留,以及特征层对齐。按信息访问程度分为黑盒与白盒蒸馏,广泛用于加速推理、降低资源消耗,同时提升泛化能力。

🎯 概述
知识蒸馏(Knowledge Distillation,KD) 是一种模型压缩技术,旨在将一个大型、复杂的模型(通常称为 教师模型)的知识迁移到一个较小、更高效的模型(称为 学生模型)上。其核心思想是,学生模型通过模仿教师模型的输出或中间表示,学习到教师模型的知识,达到接近甚至超越教师模型性能的效果。
🤔 为什么需要知识蒸馏?
模型压缩: 大模型通常参数量巨大,占用大量计算资源。知识蒸馏可以将大模型的知识浓缩到一个小模型中,从而降低模型的存储和计算成本。
加速推理: 小模型的推理速度比大模型快得多,这在实时应用中非常重要。
提高泛化能力: 通过学习大模型的知识,小模型可以更好地泛化到未见过的数据上。
🏗️ 常见模型蒸馏策略
传统模型的训练是直接通过真实数据的 硬标签 计算损失的,硬标签即离散的、确定的类别标签,通常直接对应数据的真实类别或目标输出
经典蒸馏方法(Hinton 方法)
训练过程
使用传统方法训练一个高性能的教师模型
设计相对简单的学生模型结构,具体训练方式如下:
输入样本 x 到教师模型,得到教师模型的软标签输出
p
t

(软标签指教师模型对样本的输出概率分布,是原始输出logits的归一化值)
输入同样的样本 x 到学生模型,得到学生模型的软标签输出
p
s

定义损失函数,让学生的输出概率分布 p 尽可能接近教师模型的输出概率分布 q,同时考虑样本x的真实硬标签 y
损失函数
总损失为软损失与硬损失的加权和:
L=α×L
soft

+(1−α)×L
hard

α
:权重系数
L
soft

:软损失,即学生模型输出
p
s

与教师模型软标签
p
t

之间KL散度(衡量分布差异,保证对教师模型的拟合),
L
soft

=KL(p
t

∣∣p
s

)
L
hard

:硬损失,即学生模型输出与真实硬标签之间的交叉熵损失(保证对真实类别的拟合)
KL散度计算
在机器学习、信息论和概率统计中,KL 散度(Kullback-Leibler Divergence) 是一种衡量两个概率分布之间差异的指标,也被称为相对熵(Relative Entropy)
给定两个概率分布:
真实分布 P(ground truth 或 target)
近似分布 Q(model output 或 prediction)
KL 散度用来衡量“ 用近似分布 Q 来表示真实分布 P 时所损失的信息量”, 定义如下:
离散概率分布:
KL(P∣∣Q)=
x


P(x)log
Q(x)
P(x)

连续概率分布:
KL(P∣∣Q)=∫
x

P(x)log
Q(x)
P(x)

dx
KL散度具有以下性质:
非负性:
KL(P∣∣Q)≥0
,当前仅当对所有x ,
P(x)=Q(x)
时,
KL(P∣∣Q)=0
不对称性:
KL(P∣∣Q)

=KL(Q∣∣P)
, 意味着 “用 Q 近似 P” 和 “用 P 近似 Q” 的损失可能不同
带温度的软标签
教师模型对输入 x 的 logits(未归一化输出)为
z
t

,,通过带温度的 softmax 生成软标签
p
t

=softmax(z
t

/T)
,其中 T 是温度参数( T>1 时分布更平滑,保留更多类别关联信息)。
🤔:为什么要加温度?
↩️:如果不加温度,当某个类别的得分
z
i

原大于其他类别的得分时,softmax归一化后,
p
i

接近1,而其他类别的概率接近0,失去了其他类别的概率分布信息,因此引入一个温度参数 T,使输出分布更平滑,保留更多类别关联信息。
特征蒸馏(Feature Distillation)
经典蒸馏仅利用教师的输出层知识,而中间层特征(Transformer 隐藏状态、注意力权重)包含丰富的语义信息。此类方法通过对齐师生中间层特征传递知识。
实现方式:
选择关键特征层:从教师模型中选取对任务最关键的中间层,通常是语义信息丰富的深层
特征映射与对齐:由于学生模型结构可能与教师不同(如层数、维度更少),需设计 “映射函数”(如线性变换)将学生的特征转换为与教师特征兼容的维度。
定义特征损失:通过损失函数衡量师生特征的差异,常用MSE 损失(均方误差)或余弦相似度损失
🎯 蒸馏案例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch.nn.functional as F
class DistillationLoss(nn.Module):
def init(self, temperature=4.0, alpha=0.7):
super().init()
self.T = temperature
self.alpha = alpha

def forward(self, student_logits, teacher_logits, labels):
    # 蒸馏损失
    soft_targets = F.softmax(teacher_logits / self.T, dim=-1)
    soft_pred = F.log_softmax(student_logits / self.T, dim=-1)
    distill_loss = F.kl_div(soft_pred, soft_targets, reduction='batchmean') * (self.T ** 2)

    # 学生损失
    student_loss = F.cross_entropy(student_logits, labels)

    return self.alpha * distill_loss + (1 - self.alpha) * student_loss

F.kl_div 介绍
假设 P 为目标分布,Q为正式分布(假设为离散型)
因为
KL(P∣∣Q)=
x


P(x)log
Q(x)
P(x)

=
x


P(x)logP(x)−P(x)logQ(x)=
x


P(x)(logP(x)−logQ(x))

F.kl_div内部的KL散度的实现逻辑是
KL(input,target)=∑target×(log(target)−input)

所以输入给F.kl_div的input实际是已经经过log和softmax的logits,而输入给 F.kl_div的target只是softmax过的logits
蒸馏分类
黑盒蒸馏
黑盒蒸馏中,学生模型无法访问教师模型的内部结构或中间输出,仅能获取教师模型的最终预测结果(如输入样本对应的输出概率分布),通过模仿这些 “外部输出” 来学习知识,蒸馏主要依赖教师的最终输出(软标签)作为蒸馏信号,损失函数通常仅基于输出概率的匹配(如 KL 散度损失)。
白盒蒸馏
白盒蒸馏中,学生模型可以直接访问教师模型的内部结构和中间输出(如隐藏层特征、注意力权重、logits 等),并通过模仿这些内部信息来学习教师的知识
🎯 面试重点
知识蒸馏的三个层次?
温度参数的作用?
如何选择蒸馏目标?
蒸馏与剪枝、量化的区别?

相关文章
|
C语言 C++
C/C++ 自定义头文件,及头文件结构详解
还是从"stdio.h"说起,这是C语言中内置的标准库,也就是说,头文件很多时候其实就是一个“库”,类似于代码的仓库,也就是说将某些具有特定功能的常量、宏、函数等归为一个大类,然后放进这个“仓库”,就像stdio.h就是一个标准输入/输出的头文件
534 1
|
分布式计算 资源调度 Hadoop
Hadoop: 启动后发现没有DataNode
Hadoop: 启动后发现没有DataNode
1150 0
Hadoop: 启动后发现没有DataNode
|
5月前
|
人工智能 供应链 安全
智能体开发的学习路径:对标国家职业标准的系统化能力构建
程序员陈凯苦于转型智能体开发,课程零散难入门。直到接触“智能体来了”系统化课程,依《人工智能工程技术人员国家职业标准》分三阶段进阶:1-3月打基础,掌握Python、大模型与数据库;3-6月学架构、意图识别与对话管理,达中级水平;6-12月实战企业级项目,如供应链智能体,契合高级工程师要求。课程融合API开发、安全治理与模型优化,助力从Java开发者成长为AI工程师。
|
5月前
|
数据采集 存储 安全
一文讲清:数据清洗、数据中台、数据仓库、数据治理
企业数据混乱、分析低效?根源在于数据体系不完整。本文详解数据清洗、数据仓库、数据中台与数据治理四大核心概念:从清理脏数据,到统一存储分析,再到敏捷服务业务,最后通过治理保障质量与安全,构建企业数据驱动的完整链条。
一文讲清:数据清洗、数据中台、数据仓库、数据治理
|
机器学习/深度学习 编解码 人工智能
RaptorX、AlphaFold、DeepAccNet、ESMFold…你都掌握了吗?一文总结生物制药必备经典模型(2)
RaptorX、AlphaFold、DeepAccNet、ESMFold…你都掌握了吗?一文总结生物制药必备经典模型
989 0
基于51单片机的proteus数字时钟仿真设计
基于51单片机的proteus数字时钟仿真设计
1125 1
|
8月前
|
敏捷开发 人工智能 监控
任务反馈闭环管理:打造高效执行力的17个关键环节全解析
任务反馈闭环管理是一种确保任务从布置到完成全过程信息透明的管理方法,其核心是通过"计划-执行-反馈-改进"的完整循环,解决传统管理中常见的"任务黑洞"问题。这种机制强调责任明确、流程标准化、反馈及时和持续优化,能够显著提升执行力、团队协同效率和组织的敏捷性。关键环节包括SMART目标设定、标准化执行流程、量化反馈机制和PDCA持续改进。有效的闭环管理需要制度设计、工具支持和流程优化的协同配合,并通过五大KPI(任务完成率、反馈及时率等)进行量化评估。实施闭环管理虽面临员工适应、流程复杂等挑战,但数字化转型和智能化工具的应用正推动其向更高效的方向发展。闭环管理不仅是提升效率的工具,更是促进组织持
856 0
|
存储 JavaScript Python
word文档转成Markdown文档并在Typora免费版添加图床-----想想都很香
word文档转成Markdown文档并在Typora免费版添加图床-----想想都很香
992 0
|
人工智能 运维 算法
引领企业未来数字基础架构浪潮,中国铁塔探索超大规模分布式算力
引领企业未来数字基础架构浪潮,中国铁塔探索超大规模分布式算力
530 14
|
网络协议 Unix Linux
一个.NET开源、快速、低延迟的异步套接字服务器和客户端库
一个.NET开源、快速、低延迟的异步套接字服务器和客户端库
324 4