知识蒸馏的基本思路

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: 知识蒸馏(Knowledge Distillation)是一种模型压缩方法,在人工智能领域有广泛应用。目前深度学习模型在训练过程中对硬件资源要求较高,例如采用GPU、TPU等硬件进行训练加速。但在模型部署阶段,对于复杂的深度学习模型,要想达到较快的推理速度,部署的硬件成本很高,在边缘终端上特别明显。而知识蒸馏利用较复杂的预训练教师模型,指导轻量级的学生模型训练,将教师模型的知识传递给学生网络,实现模型压缩,减少对部署平台的硬件要求,可提高模型的推理速度。

知识蒸馏的概念出自2015年Hinton发表的Distilling the Knowledge in a Neural Network论文。在分类问题中,模型的输出层采用Softmax函数预测类别概率值。模型的目标是使输出概率值最高的类别尽可能接近于实际真实目标(Hard-target),它的训练过程是对真实值求极大似然。原始数据标签通常采用独热编码。这种情况下,所有负标签被平等对待,但这与实际并不一致,例如在动物图片分类时,如果图片的真实结果是猫,而负样本老虎和蛇与猫相近程度是不同的,显然老虎与猫更像;而当模型预测为蛇时,它比预测为老虎时更差。这说明不同负样本对应的输出概率应该有区别,也就是说不同的负样本对损失函数的贡献是不同的。而软目标(Soft Target)是通过改进Softmax层输出的类别概率,对每个类别都分配概率。
软目标(Soft Target)是在Softmax函数中增加了温度参数T:
image.png

其中,zi是某个类别i的逻辑单元输出值,其值越大,说明结果属于这个类别的可能性就越大。T表示蒸馏的温度或强度,T的值越高,其输出概率分布越趋于平滑,就越放大负标签对应的信息,也就更加关注负标签。当T的值为1时,上式就变成了传统的Softmax函数。采用这种方式进行训练,对样本量要求更少,训练后的模型具有更好的泛化能力。
知识蒸馏的重点是对学生模型进行训练,其实现的过程如下:
1.通过传统方式训练比较复杂的教师模型。
2.训练学生模型,其损失分为蒸馏损失和学生损失两部分:前者借助教师模型的输出作为标签进行损失值计算,后者借助真实标签值计算损失,两个损失加权求和即可得到学生模型的总损失,从而对学生模型进行训练。
3.得到学生模型后,在推理阶段使用传统Softmax进行预测。
学生模型的训练和推理过程如图11-25所示。
在知识蒸馏之前先训练教师模型,然后基于改进的Softmax可得到教师模型已经学习的知识,并用其预测的结果作为标签与学生模型软目标输出的结果比较,计算蒸馏损失,使学生模型学习到正样本与负样本之间的关系信息。学生模型的损失除了蒸馏损失之外,还包括采用传统Softmax得到的硬目标损失,即设置T的值为1,将真实的样本标签与模型预测结果进行比较得到学生损失。
image.png

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
目录
相关文章
|
5月前
|
机器学习/深度学习 算法 Python
K最近邻算法:简单高效的分类和回归方法(三)
K最近邻算法:简单高效的分类和回归方法(三)
|
2月前
|
机器学习/深度学习 算法 数据可视化
决策树算法的原理是什么样的?
决策树算法的原理是什么样的?
106 0
决策树算法的原理是什么样的?
|
5月前
|
数据采集 存储 运维
K最近邻算法:简单高效的分类和回归方法
K最近邻算法:简单高效的分类和回归方法
|
5月前
|
机器学习/深度学习 数据采集 并行计算
K最近邻算法:简单高效的分类和回归方法(二)
K最近邻算法:简单高效的分类和回归方法(二)
|
8月前
|
机器学习/深度学习 数据采集 PyTorch
手写数字识别基本思路
手写数字识别基本思路
92 0
|
8月前
|
编解码 算法 数据挖掘
FCOS—分割思想做目标检测
FCOS—分割思想做目标检测
116 0
|
算法
决策树ID3算法和C4.5算法实战
决策树ID3算法和C4.5算法实战
102 0
决策树ID3算法和C4.5算法实战
|
XML 机器学习/深度学习 人工智能
【机器学习】集成学习(Boosting)——梯度提升树(GBDT)算法(理论+图解+公式推导)
【机器学习】集成学习(Boosting)——梯度提升树(GBDT)算法(理论+图解+公式推导)
187 0
【机器学习】集成学习(Boosting)——梯度提升树(GBDT)算法(理论+图解+公式推导)
|
机器学习/深度学习 计算机视觉 异构计算
目标检测:SppNet核心思想
目标检测:SppNet核心思想
98 0
目标检测:SppNet核心思想
|
机器学习/深度学习 算法 PyTorch
知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍
知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍
293 0
知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍

相关产品

  • 人工智能平台 PAI