知识蒸馏的基本思路

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,100CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: 知识蒸馏(Knowledge Distillation)是一种模型压缩方法,在人工智能领域有广泛应用。目前深度学习模型在训练过程中对硬件资源要求较高,例如采用GPU、TPU等硬件进行训练加速。但在模型部署阶段,对于复杂的深度学习模型,要想达到较快的推理速度,部署的硬件成本很高,在边缘终端上特别明显。而知识蒸馏利用较复杂的预训练教师模型,指导轻量级的学生模型训练,将教师模型的知识传递给学生网络,实现模型压缩,减少对部署平台的硬件要求,可提高模型的推理速度。

知识蒸馏的概念出自2015年Hinton发表的Distilling the Knowledge in a Neural Network论文。在分类问题中,模型的输出层采用Softmax函数预测类别概率值。模型的目标是使输出概率值最高的类别尽可能接近于实际真实目标(Hard-target),它的训练过程是对真实值求极大似然。原始数据标签通常采用独热编码。这种情况下,所有负标签被平等对待,但这与实际并不一致,例如在动物图片分类时,如果图片的真实结果是猫,而负样本老虎和蛇与猫相近程度是不同的,显然老虎与猫更像;而当模型预测为蛇时,它比预测为老虎时更差。这说明不同负样本对应的输出概率应该有区别,也就是说不同的负样本对损失函数的贡献是不同的。而软目标(Soft Target)是通过改进Softmax层输出的类别概率,对每个类别都分配概率。
软目标(Soft Target)是在Softmax函数中增加了温度参数T:
image.png

其中,zi是某个类别i的逻辑单元输出值,其值越大,说明结果属于这个类别的可能性就越大。T表示蒸馏的温度或强度,T的值越高,其输出概率分布越趋于平滑,就越放大负标签对应的信息,也就更加关注负标签。当T的值为1时,上式就变成了传统的Softmax函数。采用这种方式进行训练,对样本量要求更少,训练后的模型具有更好的泛化能力。
知识蒸馏的重点是对学生模型进行训练,其实现的过程如下:
1.通过传统方式训练比较复杂的教师模型。
2.训练学生模型,其损失分为蒸馏损失和学生损失两部分:前者借助教师模型的输出作为标签进行损失值计算,后者借助真实标签值计算损失,两个损失加权求和即可得到学生模型的总损失,从而对学生模型进行训练。
3.得到学生模型后,在推理阶段使用传统Softmax进行预测。
学生模型的训练和推理过程如图11-25所示。
在知识蒸馏之前先训练教师模型,然后基于改进的Softmax可得到教师模型已经学习的知识,并用其预测的结果作为标签与学生模型软目标输出的结果比较,计算蒸馏损失,使学生模型学习到正样本与负样本之间的关系信息。学生模型的损失除了蒸馏损失之外,还包括采用传统Softmax得到的硬目标损失,即设置T的值为1,将真实的样本标签与模型预测结果进行比较得到学生损失。
image.png

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
目录
相关文章
|
机器学习/深度学习 传感器 编解码
再谈注意力机制 | 运用强化学习实现目标特征提取
再谈注意力机制 | 运用强化学习实现目标特征提取
再谈注意力机制 | 运用强化学习实现目标特征提取
|
2月前
|
机器学习/深度学习 人工智能 自然语言处理
C++构建 GAN 模型:生成器与判别器平衡训练的关键秘籍
生成对抗网络(GAN)是AI领域的明星,尤其在C++中构建时,平衡生成器与判别器的训练尤为关键。本文探讨了GAN的基本架构、训练原理及平衡训练的重要性,提出了包括合理初始化、精心设计损失函数、动态调整学习率、引入正则化技术和监测训练过程在内的五大策略,旨在确保GAN模型在C++环境下的高效、稳定训练,以生成高质量的结果,推动AI技术的发展。
97 10
|
9月前
|
机器学习/深度学习 人工智能 运维
[ICLR2024]基于对比稀疏扰动技术的时间序列解释框架ContraLSP
《Explaining Time Series via Contrastive and Locally Sparse Perturbations》被机器学习领域顶会ICLR 2024接收。该论文提出了一种创新的基于扰动技术的时间序列解释框架ContraLSP,该框架主要包含一个学习反事实扰动的目标函数和一个平滑条件下稀疏门结构的压缩器。论文在白盒时序预测,黑盒时序分类等仿真数据,和一个真实时序数据集分类任务中进行了实验,ContraLSP在解释性能上超越了SOTA模型,显著提升了时间序列数据解释的质量。
|
5月前
|
机器学习/深度学习 自然语言处理
如何让等变神经网络可解释性更强?试试将它分解成简单表示
【9月更文挑战第19天】等变神经网络在图像识别和自然语言处理中表现出色,但其复杂结构使其可解释性成为一个挑战。论文《等变神经网络和分段线性表示论》由Joel Gibson、Daniel Tubbenhauer和Geordie Williamson撰写,提出了一种基于群表示论的方法,将等变神经网络分解成简单表示,从而提升其可解释性。简单表示被视为群表示的“原子”,通过这一分解方法,可以更好地理解网络结构与功能。论文还展示了非线性激活函数如何产生分段线性映射,为解释等变神经网络提供了新工具。然而,该方法需要大量计算资源,并且可能无法完全揭示网络行为。
56 1
|
9月前
|
机器学习/深度学习 数据采集 算法
|
9月前
|
存储 人工智能 自然语言处理
论文介绍:Mamba:线性时间序列建模与选择性状态空间
【5月更文挑战第11天】Mamba是新提出的线性时间序列建模方法,针对长序列处理的效率和内存问题,采用选择性状态空间模型,只保留重要信息,减少计算负担。结合硬件感知的并行算法,优化GPU内存使用,提高计算效率。Mamba在多种任务中展现出与Transformer相当甚至超越的性能,但可能不适用于所有类型数据,且硬件适应性需进一步优化。该模型为长序列处理提供新思路,具有广阔应用前景。[论文链接](https://arxiv.org/abs/2312.00752)
228 3
|
9月前
|
机器学习/深度学习 开发者
论文介绍:基于扩散神经网络生成的时空少样本学习
【2月更文挑战第28天】论文介绍:基于扩散神经网络生成的时空少样本学习
98 1
论文介绍:基于扩散神经网络生成的时空少样本学习
|
机器学习/深度学习 数据采集 PyTorch
手写数字识别基本思路
手写数字识别基本思路
180 0
|
机器学习/深度学习 人工智能 数据可视化
【Pytorch神经网络理论篇】 14 过拟合问题的优化技巧(一):基本概念+正则化+数据增大
【Pytorch神经网络理论篇】 14 过拟合问题的优化技巧(一):基本概念+正则化+数据增大
536 0
|
机器学习/深度学习 人工智能 PyTorch
【Pytorch神经网络理论篇】 34 样本均衡+分类模型常见损失函数
Sampler类中有一个派生的权重采样类WeightedRandomSampler,能够在加载数据时,按照指定的概率进行随机顺序采样。
463 0