🫗 知识蒸馏

简介: 知识蒸馏是一种模型压缩技术,通过将大模型(教师)的知识迁移到小模型(学生),提升小模型性能。核心思想是模仿教师的输出分布或中间特征,常用KL散度和温度机制优化软标签学习,兼顾推理效率与泛化能力,广泛应用于轻量化模型部署。(238字)

🎯 概述
知识蒸馏(Knowledge Distillation,KD) 是一种模型压缩技术,旨在将一个大型、复杂的模型(通常称为 教师模型)的知识迁移到一个较小、更高效的模型(称为 学生模型)上。其核心思想是,学生模型通过模仿教师模型的输出或中间表示,学习到教师模型的知识,达到接近甚至超越教师模型性能的效果。
🤔 为什么需要知识蒸馏?
● 模型压缩: 大模型通常参数量巨大,占用大量计算资源。知识蒸馏可以将大模型的知识浓缩到一个小模型中,从而降低模型的存储和计算成本。
● 加速推理: 小模型的推理速度比大模型快得多,这在实时应用中非常重要。
● 提高泛化能力: 通过学习大模型的知识,小模型可以更好地泛化到未见过的数据上。

🏗️ 常见模型蒸馏策略
传统模型的训练是直接通过真实数据的 硬标签 计算损失的,硬标签即离散的、确定的类别标签,通常直接对应数据的真实类别或目标输出

经典蒸馏方法(Hinton 方法)
训练过程
● 使用传统方法训练一个高性能的教师模型
● 设计相对简单的学生模型结构,具体训练方式如下:
○ 输入样本 x 到教师模型,得到教师模型的软标签输出 $p_t$ (软标签指教师模型对样本的输出概率分布,是原始输出logits的归一化值)
○ 输入同样的样本 x 到学生模型,得到学生模型的软标签输出 $ps$
○ 定义损失函数,让学生的输出概率分布 p 尽可能接近教师模型的输出概率分布 q,同时考虑样本x的真实硬标签 y
损失函数
总损失为软损失与硬损失的加权和:$L = \alpha \times L
{soft} + (1-\alpha) \times L{hard}$
● $\alpha$:权重系数
● $L
{soft}$:软损失,即学生模型输出 $p_s$ 与教师模型软标签 $pt$ 之间KL散度(衡量分布差异,保证对教师模型的拟合),$L{soft} = KL( p_t || ps)$
● $L
{hard}$:硬损失,即学生模型输出与真实硬标签之间的交叉熵损失(保证对真实类别的拟合)
KL散度计算
在机器学习、信息论和概率统计中,KL 散度(Kullback-Leibler Divergence) 是一种衡量两个概率分布之间差异的指标,也被称为相对熵(Relative Entropy)
给定两个概率分布:
● 真实分布 P(ground truth 或 target)
● 近似分布 Q(model output 或 prediction)
KL 散度用来衡量“ 用近似分布 Q 来表示真实分布 P 时所损失的信息量”, 定义如下:
● 离散概率分布:$KL ( P || Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}$
● 连续概率分布:$KL ( P || Q) = \int_x P(x) \log \frac{P(x)}{Q(x)} dx$
KL散度具有以下性质:
● 非负性:$KL(P||Q) \geq 0$,当前仅当对所有x , $P(x) = Q(x)$ 时,$KL(P||Q) = 0$
● 不对称性:$KL(P||Q) \neq KL(Q||P)$, 意味着 “用 Q 近似 P” 和 “用 P 近似 Q” 的损失可能不同
带温度的软标签
教师模型对输入 x 的 logits(未归一化输出)为 $z_t$,,通过带温度的 softmax 生成软标签 $p_t = softmax(z_t / T)$,其中 T 是温度参数( T>1 时分布更平滑,保留更多类别关联信息)。
🤔:为什么要加温度?
↩️:如果不加温度,当某个类别的得分 $z_i$ 原大于其他类别的得分时,softmax归一化后, $p_i$ 接近1,而其他类别的概率接近0,失去了其他类别的概率分布信息,因此引入一个温度参数 T,使输出分布更平滑,保留更多类别关联信息。

特征蒸馏(Feature Distillation)
经典蒸馏仅利用教师的输出层知识,而中间层特征(Transformer 隐藏状态、注意力权重)包含丰富的语义信息。此类方法通过对齐师生中间层特征传递知识。
实现方式:
● 选择关键特征层:从教师模型中选取对任务最关键的中间层,通常是语义信息丰富的深层
● 特征映射与对齐:由于学生模型结构可能与教师不同(如层数、维度更少),需设计 “映射函数”(如线性变换)将学生的特征转换为与教师特征兼容的维度。
● 定义特征损失:通过损失函数衡量师生特征的差异,常用MSE 损失(均方误差)或余弦相似度损失

🎯 蒸馏案例
import torch.nn.functional as F
class DistillationLoss(nn.Module):
def init(self, temperature=4.0, alpha=0.7):
super().init()
self.T = temperature
self.alpha = alpha

def forward(self, student_logits, teacher_logits, labels):
    # 蒸馏损失
    soft_targets = F.softmax(teacher_logits / self.T, dim=-1)
    soft_pred = F.log_softmax(student_logits / self.T, dim=-1)
    distill_loss = F.kl_div(soft_pred, soft_targets, reduction='batchmean') * (self.T ** 2)

    # 学生损失
    student_loss = F.cross_entropy(student_logits, labels)

    return self.alpha * distill_loss + (1 - self.alpha) * student_loss

F.kl_div 介绍
假设 P 为目标分布,Q为正式分布(假设为离散型)
因为 $KL ( P || Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)} = \sum_x P(x)\log{P(x)} - P(x)\log{Q(x)} = \sum_x P(x) ( \log{P(x)} -\log{Q(x)} )$
F.kl_div内部的KL散度的实现逻辑是 $KL(input,target) = \sum target \times (log(target) - input )$
所以输入给F.kl_div的input实际是已经经过log和softmax的logits,而输入给 F.kl_div的target只是softmax过的logits
蒸馏分类
黑盒蒸馏
黑盒蒸馏中,学生模型无法访问教师模型的内部结构或中间输出,仅能获取教师模型的最终预测结果(如输入样本对应的输出概率分布),通过模仿这些 “外部输出” 来学习知识,蒸馏主要依赖教师的最终输出(软标签)作为蒸馏信号,损失函数通常仅基于输出概率的匹配(如 KL 散度损失)。
白盒蒸馏
白盒蒸馏中,学生模型可以直接访问教师模型的内部结构和中间输出(如隐藏层特征、注意力权重、logits 等),并通过模仿这些内部信息来学习教师的知识
🎯 面试重点

  1. 知识蒸馏的三个层次?
  2. 温度参数的作用?
  3. 如何选择蒸馏目标?
  4. 蒸馏与剪枝、量化的区别?
相关文章
|
3月前
|
人工智能 自然语言处理 安全
|
存储 SQL 大数据
dataCompare大数据对比之异源数据对比
dataCompare大数据对比之异源数据对比
956 0
|
3月前
|
机器学习/深度学习 算法 关系型数据库
🎮 强化学习
强化学习(RL)是一种通过智能体与环境交互,以最大化累积奖励为目标的学习方法。核心要素包括状态、动作、奖励和策略,强调试错与延迟奖励。常见算法如Q-learning、PPO、DPO等,广泛应用于决策优化与大模型对齐人类偏好。
|
3月前
|
自然语言处理
🏗️ 主流大模型结构
本文系统梳理主流大模型架构:Encoder-Decoder、Decoder-Only、Encoder-Only与Prefix-Decoder,解析GPT、LLaMA、BERT等代表模型演进与特点,对比参数量、上下文长度等关键指标,深入探讨中文模型优化及面试高频问题,助力全面掌握大模型技术脉络。(238字)
|
3月前
|
安全 C++
📈 模型评估
模型评估涵盖能力、安全与效率三大维度,包括语言理解、知识问答、推理代码等基础能力,对齐性及推理延迟、吞吐量等效率指标。常用MMLU、C-Eval、GSM8K等基准,结合Hugging Face工具实现自动化评估,面试关注幻觉检测、指标设计与人工vs自动权衡。
|
3月前
|
缓存 算法 C++
⚡ 模型推理加速
大模型推理加速关键技术:KV-Cache减少重复计算,连续批处理提升吞吐,投机解码实现2-3倍加速,结合vLLM等工具优化部署。涵盖算法、系统与硬件协同设计,助力高效落地。
|
SQL 自然语言处理 数据库
DAIL-SQL: 发掘LLM的NL2SQL能力
最近,DAIL-SQL在魔搭创空间上线,并在NL2SQL任务上取得了新的SOTA。DAIL-SQL可以更好地利用LLM的NL2SQL能力,本文对其进行详细解读。
|
移动开发 前端开发 数据可视化
React 拖拽布局组件 Drag & Drop Layout
本文介绍了如何在React中构建拖拽布局组件,涵盖基础知识、常见问题及解决方案。首先解释了拖拽操作的三个阶段:开始、过程中和结束。接着推荐了几个常用的拖拽库,如`react-beautiful-dnd`,并详细展示了如何使用该库创建基础拖拽组件,包括安装依赖、初始化容器和处理拖拽结束事件。文章还探讨了常见问题,如拖拽不生效、性能优化、嵌套拖拽和跨浏览器兼容性,并提供了进阶技巧,如自定义样式、多列布局和集成其他UI组件。通过这些内容,读者可以掌握构建高效拖拽布局的方法。
968 16
|
存储 人工智能 Serverless
搭建文生图AI系统
随着人工智能的发展,**文本生成图像(文生图)**技术在广告创意、视觉设计、内容营销等领域应用广泛。阿里云通义千问作为先进的大语言模型,不仅具备强大的文本理解能力,还能与图像生成技术结合,实现根据文本描述自动生成高质量图像。 本博客将展示如何使用通义千问与阿里云的其他产品(如函数计算、API 网关、对象存储 OSS)搭建一个简单的文生图系统,实现用户输入文本并生成相应图像的功能。
1032 6
|
运维 负载均衡 Linux
阿里云轻量服务器最新收费标准与价格参考
阿里云轻量服务器具有灵活的镜像选择、快速上手、简便运维等优势,轻量服务器适合个人开发者和学生用来搭建网站、云端学习等场景使用,2024年截至目前国内地域有60元/月、80元/月等套餐可选,国外地域有24元/月、34元/月、67元/月等套餐可选,目前轻量应用服务器2核2G3M带宽82元1年、2核4G4M带宽298元1年。