[Knowledge Distillation]论文分析:Distilling the Knowledge in a Neural Network

简介: [Knowledge Distillation]论文分析:Distilling the Knowledge in a Neural Network

论文:Distilling the Knowledge in a Neural Network

作者:Geoffrey Hinton, Oriol Vinyals, Jeff Dean

时间:2015

一、完整代码

这里我们使用python代码进行实现

# 完整代码在这里
# 就是一下softmax
# 有时间再弄

二、论文解读

2.1 介绍

使用一系列模型预测概率的平均值即软投票机制能显著改善模型的性能,但是部署起来是比较不方便的:因为要预测很多的模型结果,再求平均;论文中提到可以把所有模型预测结果的平均值部署在一个模型里面,然后利用这一个模型来预测,这样就可以大大减少部署的难度,这种方法被称为Knowledge Distillation,即知识蒸馏

在知识蒸馏中,我们不需要关心参数数量和大小的变化,我们只需要关心经过这一系列的参数得到的结果变化,如果参数变少,一般来说100%复刻这个结果是很难的;但是我们可以以一定的比例如80%去还原当时的结果,尽管可能得到错误答案,但是错误答案的相对误差可以稍微控制;错误答案的相对误差告诉了我们很多关于繁琐的模型是如何泛化的。例如,一个宝马的形象可能被误认为垃圾车的可能性很小,但这个错误仍然比误认为胡萝卜的可能性大很多倍。

人们普遍认为,用于培训的目标函数应该尽可能接近地反映用户的真实目标。尽管如此,当真正的目标是很好地推广到新数据时,模型通常被训练以优化训练数据的性能。显然,训练模型进行泛化良好会更好,但这需要关于正确的泛化方法的信息,而这些信息通常是不可用的。然而,当我们将知识从大模型中提取出来到小模型中时,我们可以训练小模型以与大模型相同的方式进行泛化。如果繁琐的模型概括,例如,它是一个大型的平均不同的模型,一个小模型训练推广以同样的方式通常会做更好的测试数据比一个小模型训练的正常方式在相同的训练集用于训练集成。

将繁琐模型的泛化能力转移到小模型的一个明显方法是使用麻烦模型产生的类概率作为训练小模型的“软目标”。在这个转移阶段,我们可以使用相同的训练集或一个单独的“转移”集。当繁琐的模型是一个更简单的模型的大型集合时,我们可以使用它们各自的预测分布的算术或几何平均值作为软目标。当软目标高熵,他们提供更多的信息比硬目标和更少的方差之间的梯度训练情况下,所以小模型通常可以训练的数据比原始繁琐的模型和使用更高的学习率。

2.2 Distillation

在多分类问题上,神经网路依赖于softmax产生各个类别的概率,其中T是一个参数可以让输出概率变得平滑;

T越大,输出的概率越平滑;

在最简单的蒸馏形式中,知识通过在转移集上训练模型并在传输集中的每个情况下使用软目标分布来转移到蒸馏模型,该分布是通过在其softmax中使用高T的原模型或者原模型集合产生的;我们可以在在训练蒸馏模型时使用相同的T,但经过训练后,把T变为1;

当我们知道输入的正确输出时,我们可以利用对目标函数简单加权的方式去构造最终的目标函数,第一个目标函数是与软目标的交叉熵,这个交叉熵是用与蒸馏模型的softmax相同的T来生成软目标来计算的。第二个目标函数是具有正确标签的交叉熵。这是用蒸馏模型的softmax中完全相同的类来计算的,但T为1;因为在预测的时候T便是1;

对第一个目标函数求导:

T很大的时候,我们有:

image.png 的时候,我们又有:

所以,在高T,同时 image.png 的时候,蒸馏的本质相当于如下:

image.png

在较低的T下,蒸馏模型几乎不去关心那些比平均数更小的负值(平均数为0),这是潜在的优势,因为这些数几乎不受用于训练模型集合的代价函数的限制,因此它们可能非常有噪声;另一方面,那些很小的负值可能会传递关于由模型集合所获得的知识的有用信息。其中哪一种影响占主导地位是一个经验问题;我们表明,当蒸馏的模型太小,无法捕获繁琐模型中的所有知识时,不大不小的T效果最好,这强烈表明忽略大的负对数是有用的;

2.3 结果

原模型和原模型集合可以部署在一个小的蒸馏模型中,并且准确性可观:

利用soft targets即软投票机制可以达到regularization即防止过拟合的效果;

可以利用部分模型在部分类中的高准确率提高权重进而提高模型的准确度;或者对一些表现非常好的模型,给予其较高的T

三、整体总结

蒸馏可以很好地将知识从一个集成或从一个大的高度正则化的模型转移到一个更小的蒸馏模型中;


目录
相关文章
|
机器学习/深度学习 搜索推荐 算法
Learning Disentangled Representations for Recommendation | NIPS 2019 论文解读
近年来随着深度学习的发展,推荐系统大量使用用户行为数据来构建用户/商品表征,并以此来构建召回、排序、重排等推荐系统中的标准模块。普通算法得到的用户商品表征本身,并不具备可解释性,而往往只能提供用户-商品之间的attention分作为商品粒度的用户兴趣。我们在这篇文章中,想仅通过用户行为,学习到本身就具备一定可解释性的解离化的用户商品表征,并试图利用这样的商品表征完成单语义可控的推荐任务。
23843 0
Learning Disentangled Representations for Recommendation | NIPS 2019 论文解读
|
4月前
|
机器学习/深度学习 算法 TensorFlow
【文献学习】Analysis of Deep Complex-Valued Convolutional Neural Networks for MRI Reconstruction
本文探讨了使用复数卷积神经网络进行MRI图像重建的方法,强调了复数网络在保留相位信息和减少参数数量方面的优势,并通过实验分析了不同的复数激活函数、网络宽度、深度以及结构对模型性能的影响,得出复数模型在MRI重建任务中相对于实数模型具有更优性能的结论。
40 0
【文献学习】Analysis of Deep Complex-Valued Convolutional Neural Networks for MRI Reconstruction
|
机器学习/深度学习 自然语言处理 算法
【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network
【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network
【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network
|
机器学习/深度学习 算法 数据处理
Stanford 机器学习练习 Part 3 Neural Networks: Representation
从神经网络开始,感觉自己慢慢跟不上课程的节奏了,一些代码好多参考了别人的代码,而且,让我现在单独写也不一定写的出来了。学习就是一件慢慢积累的过程,两年前我学算法的时候,好多算法都完全看不懂,但后来,看的多了,做的多了,有一天就茅塞顿开。所有的困难都是一时的,只要坚持下去,一切问题都会解决的。没忍住发了点鸡汤文。
35 0
|
机器学习/深度学习 算法
Keyphrase Extraction Using Deep Recurrent Neural Networks on Twitter论文解读
该论文针对Twitter网站的信息进行关键词提取,因为Twitter网站文章/对话长度受到限制,现有的方法通常效果会急剧下降。作者使用循环神经网络(recurrent neural network,RNN)来解决这一问题,相对于其他方法取得了更好的效果。
115 0
|
机器学习/深度学习 PyTorch 测试技术
SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation 论文解读
我们提出了SegNeXt,一种用于语义分割的简单卷积网络架构。最近的基于transformer的模型由于在编码空间信息时self-attention的效率而主导了语义分割领域。在本文中,我们证明卷积注意力是比transformer中的self-attention更有效的编码上下文信息的方法。
413 0
|
机器学习/深度学习 编解码 固态存储
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(下)
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(下)
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(下)
|
机器学习/深度学习 存储 编解码
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(上)
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
【论文泛读】轻量化之MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications(上)
|
机器学习/深度学习
Re22:读论文 HetSANN An Attention-based Graph Neural Network for Heterogeneous Structural Learning
Re22:读论文 HetSANN An Attention-based Graph Neural Network for Heterogeneous Structural Learning
Re22:读论文 HetSANN An Attention-based Graph Neural Network for Heterogeneous Structural Learning
|
机器学习/深度学习 搜索推荐 算法
【推荐系统论文精读系列】(十)--Wide&Deep Learning for Recommender Systems
具有非线性特征转化能力的广义线性模型被广泛用于大规模的分类和回归问题,对于那些输入数据是极度稀疏的情况下。通过使用交叉积获得的记忆交互特征是有效的而且具有可解释性,然后这种的泛化能力需要更多的特征工程努力。在进行少量的特征工程的情况下,深度神经网络可以泛化更多隐式的特征组合,通过从Sparse特征中学得低维的Embedding向量。可是,深度神经网络有个问题就是由于网络过深,会导致过度泛化数据。
187 0
【推荐系统论文精读系列】(十)--Wide&Deep Learning for Recommender Systems