知识蒸馏的基本思路

本文涉及的产品
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: 知识蒸馏(Knowledge Distillation)是一种模型压缩方法,在人工智能领域有广泛应用。目前深度学习模型在训练过程中对硬件资源要求较高,例如采用GPU、TPU等硬件进行训练加速。但在模型部署阶段,对于复杂的深度学习模型,要想达到较快的推理速度,部署的硬件成本很高,在边缘终端上特别明显。而知识蒸馏利用较复杂的预训练教师模型,指导轻量级的学生模型训练,将教师模型的知识传递给学生网络,实现模型压缩,减少对部署平台的硬件要求,可提高模型的推理速度。

知识蒸馏的概念出自2015年Hinton发表的Distilling the Knowledge in a Neural Network论文。在分类问题中,模型的输出层采用Softmax函数预测类别概率值。模型的目标是使输出概率值最高的类别尽可能接近于实际真实目标(Hard-target),它的训练过程是对真实值求极大似然。原始数据标签通常采用独热编码。这种情况下,所有负标签被平等对待,但这与实际并不一致,例如在动物图片分类时,如果图片的真实结果是猫,而负样本老虎和蛇与猫相近程度是不同的,显然老虎与猫更像;而当模型预测为蛇时,它比预测为老虎时更差。这说明不同负样本对应的输出概率应该有区别,也就是说不同的负样本对损失函数的贡献是不同的。而软目标(Soft Target)是通过改进Softmax层输出的类别概率,对每个类别都分配概率。
软目标(Soft Target)是在Softmax函数中增加了温度参数T:
image.png

其中,zi是某个类别i的逻辑单元输出值,其值越大,说明结果属于这个类别的可能性就越大。T表示蒸馏的温度或强度,T的值越高,其输出概率分布越趋于平滑,就越放大负标签对应的信息,也就更加关注负标签。当T的值为1时,上式就变成了传统的Softmax函数。采用这种方式进行训练,对样本量要求更少,训练后的模型具有更好的泛化能力。
知识蒸馏的重点是对学生模型进行训练,其实现的过程如下:
1.通过传统方式训练比较复杂的教师模型。
2.训练学生模型,其损失分为蒸馏损失和学生损失两部分:前者借助教师模型的输出作为标签进行损失值计算,后者借助真实标签值计算损失,两个损失加权求和即可得到学生模型的总损失,从而对学生模型进行训练。
3.得到学生模型后,在推理阶段使用传统Softmax进行预测。
学生模型的训练和推理过程如图11-25所示。
在知识蒸馏之前先训练教师模型,然后基于改进的Softmax可得到教师模型已经学习的知识,并用其预测的结果作为标签与学生模型软目标输出的结果比较,计算蒸馏损失,使学生模型学习到正样本与负样本之间的关系信息。学生模型的损失除了蒸馏损失之外,还包括采用传统Softmax得到的硬目标损失,即设置T的值为1,将真实的样本标签与模型预测结果进行比较得到学生损失。
image.png

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
目录
相关文章
|
机器学习/深度学习 传感器 编解码
再谈注意力机制 | 运用强化学习实现目标特征提取
再谈注意力机制 | 运用强化学习实现目标特征提取
再谈注意力机制 | 运用强化学习实现目标特征提取
|
8月前
|
机器学习/深度学习 人工智能 运维
[ICLR2024]基于对比稀疏扰动技术的时间序列解释框架ContraLSP
《Explaining Time Series via Contrastive and Locally Sparse Perturbations》被机器学习领域顶会ICLR 2024接收。该论文提出了一种创新的基于扰动技术的时间序列解释框架ContraLSP,该框架主要包含一个学习反事实扰动的目标函数和一个平滑条件下稀疏门结构的压缩器。论文在白盒时序预测,黑盒时序分类等仿真数据,和一个真实时序数据集分类任务中进行了实验,ContraLSP在解释性能上超越了SOTA模型,显著提升了时间序列数据解释的质量。
|
8月前
|
机器学习/深度学习 异构计算
Gradformer: 通过图结构归纳偏差提升自注意力机制的图Transformer
Gradformer,新发布的图Transformer,引入指数衰减掩码和可学习约束,强化自注意力机制,聚焦本地信息并保持全局视野。模型整合归纳偏差,增强图结构建模,且在深层架构中表现稳定。对比14种基线模型,Gradformer在图分类、回归任务中胜出,尤其在NCI1、PROTEINS、MUTAG和CLUSTER数据集上准确率提升明显。此外,它在效率和深层模型处理上也表现出色。尽管依赖MPNN模块和效率优化仍有改进空间,但Gradformer已展现出在图任务的强大潜力。
182 2
|
8月前
|
机器学习/深度学习 数据采集 算法
|
8月前
|
机器学习/深度学习 开发者
论文介绍:基于扩散神经网络生成的时空少样本学习
【2月更文挑战第28天】论文介绍:基于扩散神经网络生成的时空少样本学习
90 1
论文介绍:基于扩散神经网络生成的时空少样本学习
|
8月前
|
人工智能 搜索推荐 物联网
DoRA(权重分解低秩适应):一种新颖的模型微调方法_dora模型
DoRA(权重分解低秩适应):一种新颖的模型微调方法_dora模型
400 0
|
机器学习/深度学习 数据采集 PyTorch
手写数字识别基本思路
手写数字识别基本思路
169 0
|
机器学习/深度学习 并行计算 PyTorch
用什么tricks能让模型训练得更快?先了解下这个问题的第一性原理(2)
用什么tricks能让模型训练得更快?先了解下这个问题的第一性原理
148 0
|
机器学习/深度学习 人工智能 自然语言处理
【Pytorch神经网络理论篇】 10 优化器模块+退化学习率
反向传播的意义在于告诉模型我们需要将权重修改到什么数值可以得到最优解,在开始探索合适权重的过程中,正向传播所生成的结果与实际标签的目标值存在误差,反向传播通过这个误差传递给权重,要求权重进行适当的调整来达到一个合适的输出,最终使得正向传播所预测的结果与标签的目标值的误差达到最小,以上即为反向传播的核心思想
173 0
|
机器学习/深度学习 人工智能 PyTorch
【Pytorch神经网络理论篇】 34 样本均衡+分类模型常见损失函数
Sampler类中有一个派生的权重采样类WeightedRandomSampler,能够在加载数据时,按照指定的概率进行随机顺序采样。
438 0