【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network

简介: 【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network

知识蒸馏(Knowledge Distillation)是一种模型压缩方法,由深度学习三巨头Hinton老爷子在2015年提出。深度学习在计算机视觉、语音识别、自然语言处理等内的众多领域中均取得了令人难以置信的性能。但是,大多数模型在计算上过于昂贵,无法在移动端或嵌入式设备上运行。因此需要对模型进行压缩,且知识蒸馏是模型压缩中重要的技术之一。现如今,知识蒸馏被广泛的用于模型压缩和迁移学习当中,这篇就是知识蒸馏的开山之作,今天我也一起读一下这篇论文,学习学习。


文章的标题是Distilling the Knowledge in a Neural Network,那么说明是神经网络的知识呢?一般认为模型的参数保留了模型学到的知识,因此最常见的迁移学习的方式就是在一个大的数据集上先做预训练,然后使用预训练得到的参数在一个小的数据集上做微调(两个数据集往往领域不同或者任务不同)。例如先在Imagenet上做预训练,然后在COCO数据集上做检测。在这篇论文中,作者认为可以将模型看成是黑盒子,知识可以看成是输入到输出的映射关系。因此,我们可以先训练好一个teacher网络,然后将teacher的网络的输出结果q qq作为student网络的目标,训练student网络,使得student网络的结果 p  接近 q,这就是论文的基本思想。


63646a54ebd4465cb659580afe15ca30.jpg


首先可以思考一个问题,知识蒸馏和从头训练一个网络有什么不同


这样和从头训练一个模型有什么不一样?


显然,模型越复杂,理论搜索空间越大。但是,如果我们假设较小的网络也能实现相同(甚至相似)的收敛,那么教师网络的收敛空间应该与学生网络的解空间重叠。


不幸的是,仅凭这一点并不能保证学生网络收敛在同一点。学生网络的收敛点可能与教师网络有很大的不同。但是,如果引导学生网络复制教师网络的行为(教师网络已经在更大的解空间中进行了搜索),则其预期收敛空间会与原有的教师网络收敛空间重叠。

e269236d0bae8a47d7a58d94a058a9c6.png


摘要 Abstract


要提高几乎所有机器学习算法的性能,一个非常简单的方法是在同一个数据上训练许多不同的模型并集成平均它们的预测效果。不幸的是,使用集成模型进行预测是很麻烦的,而且可能计算成本太高,不能部署到大量用户,特别是当单个模型是大型神经网络时。有研究表示,将集成模型中的知识压缩到易于部署的单个模型是有可能的,我们使用不同的压缩技术进一步发展了这种方法。

 我们在MNIST上取得了一些令人惊喜的结果,并且我们表明,通过将集成模型的知识提炼到单个模型中可以显著改进大量应用于商业系统的语音模型。我们还引入了一种新型的由一个或多个完整模型和许多能够学习区分完整模型会混淆的细粒度类的“专家模型”。不同于“专家模型”的混合,这些“专家模型”可以快速地并行训练。


介绍 Introduction


许多昆虫的幼年形态是最适合从环境中汲取能量和营养的,而成虫形态则完全不同,更适合旅行和繁殖等不同需求。昆虫的类比表明我们可以训练非常复杂的模型,其易于从数据中提取出结构。这个复杂的模型可以是独自训练模型的集成,也可以是一个用强大正则器如d r o p o u t dropoutdropout训练的单个大模型。一旦复杂模型训练完毕,之后我们可以使用一种不同的训练方式,称之为“蒸馏”,将知识从复杂的模型(称之为t e a c h e r teacherteacher模型)转移到更易于部署的小模型(称之为s t u d e n t studentstudent模型)中。


2b1f38c35ea73420e2936d8f1af9959e.png


对于t e a c h e r teacherteacher模型,其能够学习区分大量的类别,正常情况下,训练目标是最大化正确类别的平均对数概率,但这种**学习的副作用是训练的模型会将概率分配给错误的类别上,虽然这些概率值可能很小,但一些错误类别比其他错误类别的概率值大很多。**比如:一辆宝马车的照片可能有很小的概率被误认为垃圾车,但是仍比被误认为一根萝卜的概率大出很多。在错误类别上的相对概率可以反映出模型是如何进行泛化的,这也是模型学到的知识,教师网络可能可以从相关性的概率告诉学生如何去泛化。


通常上来说,我们认为模型学习到的参数代表着“知识”,这是无法迁移的,如果将模型参数定义为“知识”是无法进行操作的。论文中提出可以利用t e a c h e r teacherteacher模型的类别概率作为 “软目标” 用于训练s t u d e n t studentstudent模型,概率中含有许多隐式的信息。


**当soft target具有高熵值,在训练每一个样本时软目标能够提供比hardtragetstudent模型的roundtruth)更多的信息并且训练每一个样本时的梯度差异更小。因此,与teacher模型相比,student模型训练数据要少得多,使用的学习率也高得多。**这一部分其实也是比较容易思考的,比如我们的硬目标只有一个数据,比如(0,0,1),我们只能从中学习正确的目标,但是不能得到其他的数据,但是对于我们的teacher模型来说,他输出的是我们的软目标,里面包含着概率,比如(0.1,0.3,0.6),这个熵明显更高,包含着更多的信息,也能获取更多的知识。


a053659407a961179c43034e5c87fd44.png


  • Hard-target:原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。
  • Soft-target:Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。


如在MNIST数据集中做手写体数字识别任务,假设某个输入的“2”更加形似"3",softmax的输出值中"3"对应的概率会比其他负标签类别高;而另一个"2"更加形似"7",则这个样本分配给"7"对应的概率会比其他负标签类别高。这两个"2"对应的Hard-target的值是相同的,但是它们的Soft-target却是不同的,由此我们可见Soft-target蕴含着比Hard-target更多的信息。



66e134499d0b5da6bd282647ca24b713.png


**这是非常有价值的信息,它在数据上定义了丰富的相似性结构(例如,它说明了哪个版本的2 看起来像3 ,哪个像7 ),但它对迁移阶段的交叉熵代价函数的影响很小,因为这些概率值接近零。**之前研究的作者,解决这个的方法是利用logits对未经过oftmax函数的值,而不是经过softmax函数之后的概率值,然后将teacher的logits值和sstudent的logits值的平方差作为最小化目标。


在介绍知识蒸馏方法之前,首先得明白什么是Logits。我们知道,对于一般的分类问题,比如图片分类,输入一张图片后,经过DNN网络各种非线性变换,在网络最后Softmax层之前,会得到这张图片属于各个类别的大小数值z i,某个类别的z i 数值越大,则模型认为输入图片属于这个类别的可能性就越大。什么是Logits? 这些汇总了网络内部各种信息后,得出的属于各个类别的汇总分值 z i ,就是Logits,i代表第i个类别,z i 代表属于第i类的可能性。因为Logits并非概率值,所以一般在Logits数值上会用Softmax函数进行变换,得出的概率值作为最终分类结果概率。Softmax一方面把Logits数值在各类别之间进行概率归一,使得各个类别归属数值满足概率分布;另外一方面,它会放大Logits数值之间的差异,使得Logits得分两极分化,Logits得分高的得到的概率值更偏大一些,而较低的Logits数值,得到的概率值则更小。



50998707f6e9e82da245c221caa32616.png



**论文中,我们提出了一个通用的解决方案,叫做“蒸馏”,是提高oftmax最终值的温度,直到teacher模型产生一个合适的软目标集。**当训练student 模型来匹配这些软目标时,我们使用同样高的温度。 稍后说明,之前研究中直接slogits实际上只是蒸馏的一种特殊情况。我们还可以用这种方法在未标记的数据集中训练小模型,也可以用原始的数据集,这样我们可以进行一个扩充数据集操作。我们可以从下图看的出来,我们可以从teacher模型中得到软目标,然后对student模型进行训练。


1fe0799103314fb58b75405e9145e92c.png

蒸馏 Distillation


神经网络通常使用softmax输出层来生成类的概率,然后这一部分会对每一个l o g i t s生成对应的概率 q i


image.png


其中 T TT 表示温度,正常情况下设为 1.0 1.01.0 ,使用更高的 T TT 可以生成更加平滑的类概率分布。


最简单的蒸馏形式是通过在迁移集训练上student模型将知识迁移到student模型中,在迁移集中得每个样本使用由将温度 T 设置得 很高的teacher 模型生成软目标分布。在训练student模型时使用相同的高温,但在训练结束后,其使用温度 T = 1.0 在训练 student模型时,我们使用两个不同目标函数的加权平均,形式如下:


image.png


其中CrossEntropy ( ∗ ) (*)(∗) 为交叉熵函数, y s 表示student模型的预测结果, y s 表示teacher模型的预测结果, y yy 是student模型的真实 标签。第一个交叉熵函数是student模型预测结果与软目标的交叉熵,其使用与训练该软目标的teacher中一样的温度。第二个目标函数 是student模型与真实标签的交叉熵,其温度设置为 1.0 1.01.0 。

9f8245cebb9344668f31112349dd460c.png



我们发现最好的结果是在第二个目标函数上使用较低的权值(即: α \alphaα 取小值) 。由于软目标产生的梯度缩放了image.png ,因而在同时使 用软目标和硬目标时,将软目标的梯度乘以image.png是非常重要的。这确保了在实验过程中,如果用于蒸馏的温度发生改变,那么硬目标和软 目标的相对贡献大致保持不变。


总结一下


  • 应用蒸馏技术的大致流程如下图所示:


3b95ada83c1443c4a2b2236aadaf59d3.png

将student模型在训练软目标时,要在高温环境下进行,即T取一个较大的值,同时与它做交叉熵的teacher模型的输出也要在相同的温度情况下


student的硬目标则在T = 1 情况下进行,并且总的优化目标是软目标和硬目标的加权和


蒸馏的一种特殊形式:直接Matching Logits


直接Matching Logits指的是,,直接使用softmax层的输入logits(而不再是输出)作为Soft target, 需要最小化的目标函数是Teacher模型和Student模型的logits之间的平方差,


image.png对 z i 求梯度可得:

image.png

再看一般蒸馏中image.png  求梯度可得:

image.png

image.png


此时, 假设 Logits 在每个样本上是零均值的, 则进一步近似:


image.png

可见, 经过Softmax的蒸馏方式和直接Matching Logits的方式, 当温度 T → ∞  时Soft-target损 失函数部分是等价的, 即 Matching Logits是一般知识蒸馏方法的一种特殊形式。


对于蒸馏温度 T 的理解


在知识蒸馏中,需要使用高温将知识“蒸馏”出来,但是如何调节温度T呢,温度的变化会产生怎样的影响呢?



f916a541acbd1890ecccedd5366a2a62.png


温度 T TT 有这样几个特点:


原始的s o f t m a x softmaxsoftmax函数是 T = 1 T=1T=1 时的特例; T < 1 T<1T<1 时, 概率分布比原始更 “陡峭”, 也就是说, 当 T → 0 T \rightarrow 0T→0 时, S o f t m a x SoftmaxSoftmax 的输出值会接近于 H a r d − t a r g e t ; Hard-target;Hard−target; T > 1 T>1T>1 时, 概率分布比原始更“平 缓"。

随着 T TT 的增加, S o f t m a x SoftmaxSoftmax 的输出分布越来越平缓, 信息熵会越来越大。温度越高, $softmax $上各个值的分布就越平均, 思考极端情况, 当 T = ∞ T=\inftyT=∞, 此时s o f t m a x softmaxsoftmax的值是平均分布的。

不管温度 T TT 怎么取值, S o f t − t a r g e t Soft-targetSoft−target都有忽略相对较小的 p i p_{i}p

i

 ( T e a c h e r TeacherTeacher模型在温度为 T \mathrm{T}T 时 s o f t m a x softmaxsoftmax输出在第 i ii 类上的值) 携带的信息的倾向。

温度的高低改变的是S t u d e n t StudentStudent模型训练过程中对负标签的关注程度。当温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,S t u d e n t StudentStudent模型会相对更多地关注到负标签。


实际上,负标签中包含一定的信息,尤其是那些负标签概率值显著高于平均值的负标签。但由于Teacher模型的训练过程决定了负标签部分概率值都比较小,并且负标签的值越低,其信息就越不可靠。因此温度的选取需要进行实际实验的比较,本质上就是在下面两种情况之中取舍:


当想从负标签中学到一些信息量的时候,温度T应调高一些;

当想减少负标签的干扰的时候,温度T应调低一些;

总的来说,T 的选择和Student模型的大小有关,Student模型参数量比较小的时候,相对比较低的温度就可以了。因为参数量小的模型不能学到所有Teacher模型的知识,所以可以适当忽略掉一些负标签的信息。


最后,在整个知识蒸馏过程中,我们先让温度T升高,然后在测试阶段恢复“低温“(T = 1 ),从而将原模型中的知识提取出来,因此将其称为是蒸馏,实在是妙啊。


Experiment 实验


MNIST


在该数据集上,首先在60000个训练样本上训练了一个带有两个隐藏层 (每层有1200个单元) 的teacher网络,该网络使用 dropout和 weight-constraints的正则化方法。另外输入的图片在任意方向上抖动了两个像素。该网络取得了67个测试错误的结果,而一个带有 两个隐藏层 (每层有800个单元) 的student网络并不带正则方法取得了146个测试错误的结果。但如果这个student网络添加了由温度 设置为20的大网络软目标任务,则它的的测试错误结果为 74 个。这表明了软目标能够将大量的知识迁移到student网络中,其中包括了 从转译数据中学习到的如何泛化的知识,及时迁移数据集中不包含任何的转译。

当带有两层隐藏层的student网络中每个隐藏层中的单元数据超过 300 个时,所有高于8的温度设置都能得到相似的结果。但是将隐 藏层单元量急剧减小为 30 个时,温度在2.5-4之间的效果优于这个范围之外的温度设置。


我们尝试将迁移数据集中数字 3 的样本移除。因此从 student网络的观点看,数字 3 是一个从末见过的神秘数字。即使是这样, student网络也只造成了206个测试错误,其中133个是来自测试集中的1010个数字 3 。


speech recognition


第二个实验是在speech recognition领域,使用不同的参数训练了10个DNN,对这10个模型的预测结果求平均作为emsemble的结果,相比于单个模型有一定的提升。然后将这10个模型作为teacher网络,训练student网络。得到的Distilled Single model相比于直接的单个网络,也有一定的提升,结果见下表:


2020112619223339.png


Training ensembles of specialists on very big dataset


在这一部分呢,实际上是与我们的知识蒸馏是无关的,和集成模型有点关系,他提出了一种专才模型的方法。


类似于我们可以用一部分模型来判断特定的一部分相似的类别,或者细粒度的类别,比如拉布拉多犬和田园犬以及各种犬科动物等等,专才模型对他们进行一个训练,然后多个专才模型进行训练,可以并行训练,各自互不干扰,所以训练速度是非常快的。


本文中作者是在Google自己的数据集上进行训练的,用了数据并行和模型并行的方法,得到的结果也是非常不错,我觉得在某一方面有特别多的启发,不过这一部分我觉得深究还是需要细看论文,就不在知识蒸馏领域讲集成学习了。


Discussion


论文展示了蒸馏在将一个集成模型或一个大的正则化模型中的知识迁移到一个小的student模型中效果良好。在MNIST任务 上,即使用于训练student模型的迁移集缺少一个或多个类的例子,蒸馏的效果也非常好。对于一个应用在安卓语音搜索中的深度声学模 型,我们表明训练一个集成的深度模型的几乎所有性能提高都能够提炼到一个大小相同的更易于部署的单一神经网络上。


28792059d35445449ebc01d3179a9dc1.png

总结


本文是知识蒸馏的早期文章之一,知识蒸馏是将teacher模型 (通常的大模型或集成模型) 的知识迁移到student模型(通常是小模型),具体做法是将teacher模型的学习到的类分类概率分布作为student模型学习的软目标,student模 型最终的目标函数是软目标与自身训练的硬目标的加权和。作者在 mnist任务和声学模型上进行实验,结果显示知识蒸馏的效果良好,能够将知识进行很好的迁移。文中作者还介绍了一种集成专家模型,能够对易于混淆的类别进行区分,并具有良好的并行性。本文知识蒸 馏主要有以下几个要点需要注意:


软目标需要在高温情况下进行 ( T TT 取较大值),而硬目标则为常规的 softmax 函数 ( T = 1.0 )

软目标和硬目标的加权和作为最终优化目标时,软目标梯度更新时需要乘以 image.png

知识蒸馏有一定的防止过拟合的作用

知识蒸馏起到了与网络剪枝一样的效果,压缩了模型的大小并且保持了模型的性能


参考


https://jishuin.proginn.com/p/763bfbd57a41

https://www.bilibili.com/video/BV1N44y1n7mU/

https://blog.csdn.net/ZY_miao/article/details/110182948

相关文章
|
机器学习/深度学习 编解码 自然语言处理
Vision Transformer 必读系列之图像分类综述(二): Attention-based(上)
Transformer 结构是 Google 在 2017 年为解决机器翻译任务(例如英文翻译为中文)而提出,从题目中可以看出主要是靠 Attention 注意力机制,其最大特点是抛弃了传统的 CNN 和 RNN,整个网络结构完全是由 Attention 机制组成。为此需要先解释何为注意力机制,然后再分析模型结构。
864 0
Vision Transformer 必读系列之图像分类综述(二): Attention-based(上)
|
5月前
|
机器学习/深度学习 算法
【博士每天一篇文献-综述】A wholistic view of continual learning with deep neural networks Forgotten
本文提出了一个整合持续学习、主动学习(active learning)和开放集识别(open set recognition)的统一框架,基于极端值理论(Extreme Value Theory, EVT)的元识别方法,强调了在深度学习时代经常被忽视的从开放集识别中学习识别未知样本的教训和主动学习中的数据查询策略,通过实证研究展示了这种整合方法在减轻灾难性遗忘、数据查询、任务顺序选择以及开放世界应用中的鲁棒性方面的联合改进。
44 6
|
5月前
|
机器学习/深度学习 存储 算法
【博士每天一篇文献-综述】Continual lifelong learning with neural networks_ A review
这篇综述论文深入探讨了神经网络在终身学习领域的研究进展,包括生物学启发的概念、终身学习方法的分类与评估,以及未来研究方向,旨在解决神经网络在学习新任务时如何避免灾难性遗忘的问题。
58 2
|
5月前
|
机器学习/深度学习 存储 人工智能
【博士每天一篇文献-综述】Brain-inspired learning in artificial neural networks a review
这篇综述论文探讨了如何将生物学机制整合到人工神经网络中,以提升网络性能,并讨论了这些整合带来的潜在优势和挑战。
51 5
|
8月前
|
Python
[Knowledge Distillation]论文分析:Distilling the Knowledge in a Neural Network
[Knowledge Distillation]论文分析:Distilling the Knowledge in a Neural Network
49 1
|
8月前
|
机器学习/深度学习
[Highway]论文实现:Highway Networks
[Highway]论文实现:Highway Networks
43 2
|
机器学习/深度学习 存储 自然语言处理
论文推荐:Rethinking Attention with Performers
重新思考的注意力机制,Performers是由谷歌,剑桥大学,DeepMind,和艾伦图灵研究所发布在2021 ICLR的论文已经超过500次引用
149 0
|
机器学习/深度学习 编解码 固态存储
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(下)
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(下)
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(下)
|
机器学习/深度学习 存储 编解码
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(上)
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(上)
|
机器学习/深度学习 存储 传感器
Unsupervised Learning | 对比学习——13篇论文综述
Unsupervised Learning | 对比学习——13篇论文综述
2129 0
Unsupervised Learning | 对比学习——13篇论文综述