知识蒸馏的基本思路

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
模型训练 PAI-DLC,5000CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: 知识蒸馏(Knowledge Distillation)是一种模型压缩方法,在人工智能领域有广泛应用。目前深度学习模型在训练过程中对硬件资源要求较高,例如采用GPU、TPU等硬件进行训练加速。但在模型部署阶段,对于复杂的深度学习模型,要想达到较快的推理速度,部署的硬件成本很高,在边缘终端上特别明显。而知识蒸馏利用较复杂的预训练教师模型,指导轻量级的学生模型训练,将教师模型的知识传递给学生网络,实现模型压缩,减少对部署平台的硬件要求,可提高模型的推理速度。

知识蒸馏的概念出自2015年Hinton发表的Distilling the Knowledge in a Neural Network论文。在分类问题中,模型的输出层采用Softmax函数预测类别概率值。模型的目标是使输出概率值最高的类别尽可能接近于实际真实目标(Hard-target),它的训练过程是对真实值求极大似然。原始数据标签通常采用独热编码。这种情况下,所有负标签被平等对待,但这与实际并不一致,例如在动物图片分类时,如果图片的真实结果是猫,而负样本老虎和蛇与猫相近程度是不同的,显然老虎与猫更像;而当模型预测为蛇时,它比预测为老虎时更差。这说明不同负样本对应的输出概率应该有区别,也就是说不同的负样本对损失函数的贡献是不同的。而软目标(Soft Target)是通过改进Softmax层输出的类别概率,对每个类别都分配概率。
软目标(Soft Target)是在Softmax函数中增加了温度参数T:
image.png

其中,zi是某个类别i的逻辑单元输出值,其值越大,说明结果属于这个类别的可能性就越大。T表示蒸馏的温度或强度,T的值越高,其输出概率分布越趋于平滑,就越放大负标签对应的信息,也就更加关注负标签。当T的值为1时,上式就变成了传统的Softmax函数。采用这种方式进行训练,对样本量要求更少,训练后的模型具有更好的泛化能力。
知识蒸馏的重点是对学生模型进行训练,其实现的过程如下:
1.通过传统方式训练比较复杂的教师模型。
2.训练学生模型,其损失分为蒸馏损失和学生损失两部分:前者借助教师模型的输出作为标签进行损失值计算,后者借助真实标签值计算损失,两个损失加权求和即可得到学生模型的总损失,从而对学生模型进行训练。
3.得到学生模型后,在推理阶段使用传统Softmax进行预测。
学生模型的训练和推理过程如图11-25所示。
在知识蒸馏之前先训练教师模型,然后基于改进的Softmax可得到教师模型已经学习的知识,并用其预测的结果作为标签与学生模型软目标输出的结果比较,计算蒸馏损失,使学生模型学习到正样本与负样本之间的关系信息。学生模型的损失除了蒸馏损失之外,还包括采用传统Softmax得到的硬目标损失,即设置T的值为1,将真实的样本标签与模型预测结果进行比较得到学生损失。
image.png

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
目录
相关文章
|
2月前
|
机器学习/深度学习 算法 Python
【机器学习】面试问答:决策树如何进行剪枝?剪枝的方法有哪些?
文章讨论了决策树的剪枝技术,包括预剪枝和后剪枝的概念、方法以及各自的优缺点。
54 2
|
机器学习/深度学习 算法 Python
|
4月前
|
机器学习/深度学习 算法 前端开发
集成学习思想
**集成学习**是通过结合多个预测模型来创建一个更强大、更鲁棒的系统。它利用了如随机森林、AdaBoost和GBDT等策略。随机森林通过Bootstrap抽样构建多个决策树并用多数投票决定结果,增强模型的多样性。Boosting,如Adaboost,逐步调整样本权重,使后续学习器聚焦于前一轮分类错误的样本,减少偏差。GBDT则通过拟合残差逐步提升预测精度。这些方法通过组合弱学习器形成强学习器,提高了预测准确性和模型的鲁棒性。
|
3月前
|
算法 Python
决策树算法详细介绍原理和实现
决策树算法详细介绍原理和实现
|
4月前
|
机器学习/深度学习 算法 搜索推荐
KNN算法(k近邻算法)原理及总结
KNN算法(k近邻算法)原理及总结
|
5月前
|
机器学习/深度学习 算法 数据可视化
突出最强算法模型——回归算法 !!
突出最强算法模型——回归算法 !!
115 1
|
10月前
|
机器学习/深度学习 算法 数据可视化
决策树算法的原理是什么样的?
决策树算法的原理是什么样的?
223 0
决策树算法的原理是什么样的?
|
5月前
|
算法
朴素贝叶斯典型的三种算法
朴素贝叶斯主要有三种算法:贝努利朴素贝叶斯、高斯贝叶斯和多项式贝叶斯三种算法
|
机器学习/深度学习 人工智能 算法
强化学习基础篇【1】:基础知识点、马尔科夫决策过程、蒙特卡洛策略梯度定理、REINFORCE 算法
强化学习基础篇【1】:基础知识点、马尔科夫决策过程、蒙特卡洛策略梯度定理、REINFORCE 算法
 强化学习基础篇【1】:基础知识点、马尔科夫决策过程、蒙特卡洛策略梯度定理、REINFORCE 算法
|
机器学习/深度学习 数据采集 PyTorch
手写数字识别基本思路
手写数字识别基本思路
146 0