知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍

简介: 知识蒸馏 | 知识蒸馏的算法原理与其他拓展介绍

框架


image.png

1)第一个方向是把一个已经训练好的臃肿的网络进行瘦身

权值量化:把模型的权重从原来的32个比特数变成用int8,8个比特数来表示,节省内存,加速运算

剪枝:去掉多余枝干,保留有用枝干。分为权重剪枝和通道剪枝,也叫结构化剪枝和非结构化剪枝,一根树杈一根树杈的剪叫非结构化剪枝,也可以整层整层的剪叫结构化剪枝。

2)第二个方向是在设计时就考虑哪些算子哪些设计是轻量化的

轻量化网络有很多需要考虑的内容:参数量、计算量

3)第三个方向是在数值运算的角度来加速各种算子的运算

比如im2col+GEMM,就是把卷积操作转成矩阵操作,矩阵操作是很多算法库里内置的功能,比如py,tf和matlab都有底层的加速到极致的矩阵运算的算子

4)第四个方向就是硬件部署

用英伟达的TensorRT库,把模型压缩成中间格式,部署在Jetson开发板上;Tensorflow-slim和Tensorflow-lite是tensorflow轻量化的生态;因特尔的openvino;FPGA集成电路也可以部署人工智能算法


1. 知识蒸馏的算法原理


1.1 知识的表示与迁移

image.png

把左边的马图像喂给分类模型,会有很多类别,每个类别识别出一个概率,训练网络时,我们只会告诉网络,这张图片是马,其余是驴是汽车的概率都是0,这个就是hard targets,用hard targets训练网络,但这就相当于告诉网络,这就是一匹马,不是驴不是车,而且不是驴不是车的概率是相等的,这是不科学的。若是把马的图片喂给已经训练好的网络里面,网络给出soft targets这个结果,是马的概率为0.7,为驴的概率为0.25,为车的概率是0.05,所以soft targets就传递了更多的信息


总结:Soft Label包含了更多“知识”和“信息,像谁,不像谁,有多像,有多不像,特别是非正确类别概率的相对大小(驴和车)


此外还引入蒸馏温度T,把原来比较硬的soft targets变的更软,更软的soft targets去训练学生网络,那些非正确类别概率的信息就暴露的越彻底,相对大小的知识就暴露出来,让学生网络去学

image.png

  • T为1,就是原softmax函数,softmax本来就是把每个类别的logic强行变成0-1之间的概率,并且求和为1,是有放大差异的功能,如果logic高一点点,经过softmax,都会变的很高。
  • T越小,非正确类别的概率相对大小的信息就会暴露的更明显;T越大,曲线就会变得更soft,高的概率给降低,低的概率会变高,贫富差距就没有了。


关于对softmax的温度测试同样可以见我另外的一篇博客,其中包含对温度改变后logits的变化。


1.2 训练流程

9c40f7b9824a4b7096e8eb89e86a9207.png

总的损失分为两个部分:

1)Distillation loss:一部分是来自于利用温度T进行软分类,也就是与教师网络的输出结果进行交叉熵损失

2)Student loss:另一部分是进行原始的硬分类,也就是与真实标签进行交叉熵损失

总的损失就是以上两个损失的加权和,Distillation loss与Student loss的具体计算方法见上图所示。

image.png

具体来说,知识蒸馏的流程是,一方面让Student模型去拟合Teacher网络输出的软标签信息,从而让Student网络可以学校到一些潜在的语义信息归纳Teacher网络的经验;另一方面,让Student网络与真实的硬标签做一个交叉熵损失了解真实数据的差异,两种损失通过一个权重相加形成总损失。


1.3 推理过程

image.png

当训练完成后,推理过程中就不需要温度为T去进行测试了,直接对网络输出的logit进行softmax进行预测即可。Teacher网络是比较臃肿的,Student网络是比较轻巧的。也就是可以利用一个比较轻巧的模型学习到一个质量比较好但是参数量比较大的模型,然后就可以部署在嵌入式设备中。


1.4 KD与Labe Smoothing的区别

image.png

Label Smoothing是为了模型太过自信,为此给予其他类别也有一点的分数,也就是杜绝了模型读cat类别的100%预测,使其拥有一点回流余地,但是很明显Label Smoothing会丢失很多的信息,其不能判断类别之间的关系,不能判断类别之间有多像与有多不像,所以Labe Smoothing是没有Soft label进行蒸馏的效果好的。

image.png


2. 知识蒸馏的应用场景


image.png

1)无监督的训练

将海量没有标签的数据集输入到已经训练好的Teacher网络中,获得的soft label就可以指导训练Student网络,这是无监督的一种方式。

2)Few Shot/Zero Shot

由于Teacher网络将经验传授了给Student网络,这使得就算训练集中没有测试集中要出现的数据,也就是Student网络可能从来没有见过某一类的数据,但在测试的过程中仍然可以对其进行正确分类。又或者只给Student网络提供少量的某种类型的数据,还是可以在测试过程中对这种少样本的数据集进行正确的分类。

3)防止过拟合

对于大模型的训练很容易会出现过拟合的情况,所以需要设置一些正则化Dropout或者是数据增强Data Aug来增加模型的泛化能力。而对于Student网络来说,由于需要部署所以是轻量级,模型参数肯定比较小,所以训练Student网络不容易出现过拟合的情况。

4)模型压缩

知识蒸馏的最重要的目的就是为了让一个轻量级的模型可以获得重量级模型的经验,从而可以轻易的部署在移动端或者是嵌入式端中,而且这种soft label的训练方式可以有效的指导Student网络。


3. 知识蒸馏的背后机理


question:为什么知识蒸馏的效果这么好,这里有一个有说服力的解释:

1)解释一

image.png

就是说,对于大模型来说,其可解的空间可能会比较大(比如Teacher网络的可解空间是绿色区域),更容易找到一个比较好的解;而对于一个小模型,其可解的空间相比大模型来说会比较小(比如Student网络的可解空间是蓝色区域),那么其找一个比较好的解可能比较困难而且也其可解区域不完全包括大模型的解集,优化的方向也难以控制。

而Student网络的作用就体现出来了,大模型的红色解集区域会慢慢引导小模型的黄色解集区域到一个比较靠近大模型解集附近的也橙色区域,这就是知识蒸馏的一个知识引导迁移的作用,让小模型获得一个更好的解,毕竟一般来说大模型的解肯定要比小模型的解要好的。


2)解释二

image.png

在Bert中也用到了知识蒸馏技术,其中它也给了一个有说服力的解释:在训练一个大型的语言模型时,会训练出很多比较容易的特征;而迁移学习与微调就是把这些冗余的特征精选出一些有用的特征来进行泛化和迁移。而如何获取揽括比较多的有用特征,就只能是模型参数量大一点,显而易见对于小模型来说其参数量是比较小的,所以其很难从浩如烟海的数据中找出比较有用的特征,所以队大模型进行知识蒸馏就是告诉小模型哪里知识的冗余的哪些特征是有用的。

这样通过蒸馏的技术,相当于是把大模型进行了一个精炼,让小模型知道大模型的哪些参数哪些部位是有用的,从而达到了模型压缩的效果。


4. 知识蒸馏的发展趋势


image.png

主要方向:

1)教学相长

之前一直是大模型老师网络来教导学生网络,但是其实也可以通过学生网络反过来指导教师网络,使得教师网络可以进一步成长。两个网络互帮互助,相互学习,我觉得这种模式其实是互学习的一种。

2)助教,多个老师/同学

在刚刚的角色中只有一个老师网络与一个学生网络,但是可以引入多个老师多个学生的模式,甚至是进入一个助教的模式。也就是不需要全部问题都问实力强厚的老师网络,可以先从助教网络来吸取部分经验,再让老师网络占总体方向的引导,也就是分工来指导学生网络。

3)多模态,知识图谱,预训练大模型的知识蒸馏

image.png

4)知识的表示(中间层),数据集蒸馏,对比学习

对于刚刚所展示的知识蒸馏,其实用的是网络输出最后一层的soft target表示出来的,那么其实网络的中间层也有尝试的解剖出来进行知识蒸馏,整个中间层的结果可以是feature map,可以是feature map构建出来的自注意力图,也可以是层之间的关系。几篇例子:

  • Attention Transfer:用中间层的feature map来进行知识蒸馏
  • Channel-wise knowledge distillation for dense prediction:用中间层的注意力图来进行知识蒸馏
  • Contrastive Representation Distillation:用对比学习来进行知识蒸馏

image.png


下面分别对几种知识蒸馏的知识表示进行拓展展示:

  • Response-Based Knowledge:把预测结果作为知识的表示

image.png

  • Feature-Based Knowledge:把中间层作为知识的表示

image.png

  • Relation-Based Knowledge:把注意图之间的关系作为知识的表示

image.png

彩蛋:代码库工具

ps:在视频的随后,up还贴了几个知识蒸馏的代码块工具,这里我顺便贴出来

1)MMRazor:OpenMMLab模型压缩工具(github/open-mmlab/mmrazor)

2)MMDeploy:OpenMMLab模型转换与部署工具箱(github/open-mmlab/mmdeploy)

3)RepDistiller:12个SoTA知识蒸馏算法的Pytorch复现(github/Hobbitlong/RepDistiller)


参考资料:

1)https://www.bilibili.com/video/BV1N44y1n7mU

2)https://www.bilibili.com/read/cv15391720?from=note


目录
相关文章
机器学习/深度学习 算法 自动驾驶
578 0
|
3月前
|
机器学习/深度学习 算法 搜索推荐
从零开始构建图注意力网络:GAT算法原理与数值实现详解
本文详细解析了图注意力网络(GAT)的算法原理和实现过程。GAT通过引入注意力机制解决了图卷积网络(GCN)中所有邻居节点贡献相等的局限性,让模型能够自动学习不同邻居的重要性权重。
515 0
从零开始构建图注意力网络:GAT算法原理与数值实现详解
|
4月前
|
机器学习/深度学习 算法 文件存储
神经架构搜索NAS详解:三种核心算法原理与Python实战代码
神经架构搜索(NAS)正被广泛应用于大模型及语言/视觉模型设计,如LangVision-LoRA-NAS、Jet-Nemotron等。本文回顾NAS核心技术,解析其自动化设计原理,探讨强化学习、进化算法与梯度方法的应用与差异,揭示NAS在大模型时代的潜力与挑战。
1006 6
神经架构搜索NAS详解:三种核心算法原理与Python实战代码
|
4月前
|
传感器 算法 定位技术
KF,EKF,IEKF 算法的基本原理并构建推导出四轮前驱自主移动机器人的运动学模型和观测模型(Matlab代码实现)
KF,EKF,IEKF 算法的基本原理并构建推导出四轮前驱自主移动机器人的运动学模型和观测模型(Matlab代码实现)
154 2
|
4月前
|
算法
离散粒子群算法(DPSO)的原理与MATLAB实现
离散粒子群算法(DPSO)的原理与MATLAB实现
201 0
|
5月前
|
机器学习/深度学习 人工智能 编解码
AI视觉新突破:多角度理解3D世界的算法原理全解析
多视角条件扩散算法通过多张图片输入生成高质量3D模型,克服了单图建模背面细节缺失的问题。该技术模拟人类多角度观察方式,结合跨视图注意力机制与一致性损失优化,大幅提升几何精度与纹理保真度,成为AI 3D生成的重要突破。
452 0
|
5月前
|
算法 区块链 数据安全/隐私保护
加密算法:深度解析Ed25519原理
在 Solana 开发过程中,我一直对 Ed25519 加密算法 如何生成公钥、签名以及验证签名的机制感到困惑。为了弄清这一点,我查阅了大量相关资料,终于对其流程有了更清晰的理解。在此记录实现过程,方便日后查阅。
565 1
|
6月前
|
消息中间件 存储 缓存
zk基础—1.一致性原理和算法
本文详细介绍了分布式系统的特点、理论及一致性算法。首先分析了分布式系统的五大特点:分布性、对等性、并发性、缺乏全局时钟和故障随时发生。接着探讨了分布式系统理论,包括CAP理论(一致性、可用性、分区容错性)和BASE理论(基本可用、软状态、最终一致性)。文中还深入讲解了两阶段提交(2PC)与三阶段提交(3PC)协议,以及Paxos算法的推导过程和核心思想,强调了其在ZooKeeper中的应用。最后简述了ZAB算法,指出其通过改编的两阶段提交协议确保节点间数据一致性,并在Leader故障时快速恢复服务。这些内容为理解分布式系统的设计与实现提供了全面的基础。
|
6月前
|
存储 算法 安全
Java中的对称加密算法的原理与实现
本文详细解析了Java中三种常用对称加密算法(AES、DES、3DES)的实现原理及应用。对称加密使用相同密钥进行加解密,适合数据安全传输与存储。AES作为现代标准,支持128/192/256位密钥,安全性高;DES采用56位密钥,现已不够安全;3DES通过三重加密增强安全性,但性能较低。文章提供了各算法的具体Java代码示例,便于快速上手实现加密解密操作,帮助用户根据需求选择合适的加密方案保护数据安全。
424 58