神经网络知识蒸馏

简介: 翻译:《Distilling the knowledge in a neural network》

We found that the best results were generally obtained by using a condiderably lower weight on the second objective function. Since the magnitudes of the gradients produced by the soft targets scale as $1/T^ 2$ it is important to multiply them by $T^2$ when using both hard and soft targets.

贴个代码例子

KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
F.cross_entropy(outputs, labels) * (1. - alpha)

代码来源:github:knowledge-distillation-pytorch.

在这行代码中,第一个损失是软目标的损失,使用了$T^2$对量级进行调整,第二个损失是真实label的硬目标损失。因为这里是使用了pytorch的,要注意的是这个output不是softmax的输出,pytorch自己的F.cross_entropy()会对outputs进行处理,所以不需要在最后加一个softmax层。


一个非常简单的几乎可以提升所有机器学习算法的表现的方法是在同一个数据集上训练多个不同的模型,然后对他们的预测结果取均值。然而呢,使用一个集成模型做预测太过麻烦,而且在用户较多时可能需要很大的计算量,尤其在其中的子模型是较大的神经网络时。Caruana和他的小伙伴们研究表表明一个集成模型的知识是可以被压缩到单独一个更容易部署的模型,我们在这个方法的基础上研究出一个不同的压缩技巧。我们在MNIST数据集上实现了惊人的结果并且发现我们能通过将一个集成模型的知识蒸馏到一个单独的模型中来提高大型商业系统的模型的表现。我们也介绍了一个新的包括一个或多个full models和一些针对性模型(能够区分full model无法确定的细粒度)的集成方法。不像多专家模型,这些针对性模型可以被快速地并行训练。

Introduction

很多昆虫都有幼虫期,这一时期最适合从环境中吸取能量和营养物质。成虫期则是一个完全不同的状态,更易于种族迁移和繁衍。在大规模机器学习中,我们我们通常在训练阶段和部署阶段使用相似的模型,即使他们的需求完全不同:对于语音或物体识别这样的任务,训练时必须从冗余度很高的数据集中提取结构,但是并不需要实时处理,并且它需要大量的计算。部署给大量用户则有更严格的在时间和计算资源上的要求。和昆虫的类比告诉我们,如果想要更容易地从数据中提取结构,我们应该愿意训练老师模型。这个老师模型可以是几个分开训练的模型的集成活着一个非常大的使用了类似dropout等正则化方法的模型。我们称为“蒸馏”的方法可以实现知识在老师模型和小模型中的转移,所以更适合这种部署。Rich Caruana早就探究过这个策略的另一个版本。在他们的论文中,他们得出可信的结论:从大型集成模型中获取的知识可以转移到小的单一模型上。

一个妨碍对这个先进方法的探究的理解上的阻碍是“我们倾向于用学习到的参数值来识别训练模型中的知识,但这样我们就没办法弄清楚怎么在改变模型形式的情况下保持同样的知识。一个更抽象的对知识的看法是,知识是从输入向量到输出向量的映射。对学习分辨不同类别的老师模型来说,一般训练目标都是最大化正确答案的平均log概率,但是一个副作用是这个训练模型会给所有不正确的答案分配概率,即使当这些概率都非常小的时候,它们中的一些也会比其余的大很多。不正确的答案的相对概率告诉我们很多这个老师模型是如何泛化的。一个BMW的图片,可能会有小概率被识别成一个垃圾车,但是这个概率仍然比把它识别成胡萝卜的概率大。

被广泛接受的一点是,用来训练的目标函数需要尽可能反映出用户的真实目标。尽管如此,在真实目标是在新模型上泛化好时,模型也经常被训练来取得在训练数据上的好的表现。很明显训练模型是为了有更好的泛化效果,但是这要求知道如何正确的进行泛化,而这个信息通常是不可得的。当我们从一个大模型往小模型蒸馏信息的时候,我们能像训练大模型一样训练小模型进行泛化。如果这个老师的模型泛化很好,比如说,它是一个大型多模型集成的结果的平均,一个按照这样方法训练的小模型在测试集上表现会比普通方法训练的小模型要好。

一个可以将老师模型的泛化能力转移到小模型上的方法是使用老师模型产生的分类概率作为训练这个小模型的“soft targets”软目标。对这个转移阶段,我们能够用同样的训练集或者一个分离的“transfer”集。当这个老师模型是几个简单模型的集成时,我们能使用他们的预测分布的算术或几何均值作为软目标。当这个软目标有比较高的熵,它们比硬目标能提供更多的信息,并且在梯度上方差更小,所以小模型能经常在比原来的老师模型使用的数据集更小的数据上训练并且使用一个更高的学习率。

对于像MNIST这样的任务,老师的模型通常会分配给正确的结果非常高的置信度,大部分和学习到的函数有关的信息会存在软目标中非常小的概率的比例中。比如说,一个2可能会被分配到$10^{-6}$的概率被认为是3或者$10^{-9}$的概率被认为是7。这些信息定义了数据间的相似结构,但是它对转移阶段的交叉熵损失函数影响很小,因为这个概率太接近0啦。Caruana和他的同事们巧妙地绕开了这个问题,他们使用logits而不是softmax产生的概率作为小模型学习的目标,随后他们最小化老师模型和小模型产生的logits间的方差。我们采用的方法,称为“蒸馏”,就是提升最终的final softmax的温度知道老师模型产生一个合适的目标的软集合。我们在训练小模型的时候也使用一样的高温来匹配这些软目标。我们之后会展示匹配老师模型的logits只是蒸馏的一个特别例子。

被用来训练小模型的转移集合可以完全由未标注数据组成。我们也能使用原始的训练集合。我们发现使用原始的训练集合效果也很好,尤其在我们在目标函数中添加了鼓励小模型预测目标和匹配老师模型提供的软目标的项之后。通常,小模型不能够完全匹配软目标,但是在正确的答案的方向出现错误其实是有帮助的。

Distillation

神经网络通常用softmax输出层来产生类别概率,softmax输出层可以通过比较将为每个类别计算的logits$z_i$转化成概率$q_i$。$T$是一个经常被设为1的温度。使用越高的$T$产生的概率分布越平滑。
$$ q_i=\frac{exp(z_i/T)}{\sum_jexp(z_j/T)} $$
在蒸馏的最简单的形式中,知识在转移集合上进行训练,并且使用在softmax上带高温的老师模型对集合中每个样例产生的软目标分布,而后知识可以通过以上过程被转移到蒸馏模型上。在训练蒸馏模型时,会使用一样的高温,但是之后会使用温度1。

当转移集合的正确标签是可知的时,这个方法也能通过训练蒸脸模型产生正确标签来提升。一个方法是使用正确的标签对软目标进行修正,但是我们发现一个更好的方法是简单的使用两个目标函数的加权平均。第一个目标函数是软目标的交叉熵,这个交叉熵是通过在蒸馏模型和老师模型的softmax中使用相同的高温来计算的。第二个目标函数是正确标签的交叉熵。这个是在蒸馏模型中使用一样的logits但是使用温度1计算得出的。我们发现对第二个目标函数使用相对较低的权重往往能取得更好的结果。因为软目标产生的梯度的量级会用$1/T^2$缩放,所以同时使用硬目标和软目标时要用$T^2$与他们相乘。这确保了即使在蒸馏的温度发生改变的时候,硬目标和软目标的相对贡献保持不变。

Matching logits is a special case of distillation

转移集合中的每个样例都会贡献一个交叉熵梯度$dC/dz_i$,对应于蒸馏模型的每个logit $z_i$。如果这个老师模型有logits $v_i$,它可以产生软目标概率$p_i$并且转移训练是在温度$T$上完成的。梯度通过以下公式计算:
$$ \frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T}(\frac{e^{z_i/T}}{\sum_je^{z_j/T}}-\frac{e^{v_i/T}}{\sum_je^{v_j/T}}) $$
如果温度比logits的量级的高,我们会有近似:
$$ \frac{\partial C}{\partial z_i}\approx \frac{1}{T}(\frac{1+z_i/T}{N+\sum_jz_j/T}-\frac{1+v_i/T}{N+\sum_jv_j/T}) $$
如果我们现在假设从每个转移样例中得到的logits是零均值的,所以$\sum_jz_j=\sum_jv_j=0$。所以上面的式子可以化简成:
$$ \frac{\partial C}{\partial z_i}\approx\frac{1}{NT^2}(z_i-v_i) $$
所以在高温限制下,蒸馏其实就是最小化$1/2(z_i-v_i)^2$。在低温情况下,蒸馏对匹配比均值小很多的的logitas的关注较小。一般情况下,logits并不会被受限于用来训练老师模型的损失函数,它们总是有噪声的,所以蒸馏的特点其实是一个优点。在另一方面,非常负的logits可以传输与老师模型获取的知识有关的有用的信息。哪个影响更大是一个经验性的问题。我们发现当蒸馏的模型太小以至于不能获取老师模型中的全部知识时,中等温度的效果最好,这也说明了忽略特别负的logits值是有用的。

Preliminary experiments on MNIST

为了了解蒸馏的工作效果,我们在60000个样例上训练了一个大的拥有两个包括1200个线性修正隐藏单元隐藏层神经网络。这个网络用了dropout和weight-constraints进行正则化。dropout能够被看成训练一个指数级共享权重的大型集成模型的方法。此外,输入图像在任意方向进行啦抖动。这个网络实现了67测试错误,然而一个更小的拥有两个包含800个线性修正隐藏单元的隐藏层的网络在没有正则化的情况下实现了146个错误。

但是如果更小的网络通过添加匹配由20温度的大网络产生的软目标的任务来进行正则化,它可以实现74的测试错误。这说明软目标能将大量的知识转移到蒸馏模型,包括在翻译训练数据上学习的关于如何泛化的知识,即使转移集合不包括任何与翻译有关的内容。

当这个蒸馏网络在它的两个隐藏层有300或更多的单元,温度都高于8时会给更多相对相似的结果。但当隐藏单元数减少到30时,温度在2.5到4的范围内时效果会明显更好。

接下来我们尝试删掉所有的数字3的例子。从蒸馏模型的角度看,3是一个它从来没有见过的神秘的字母。尽管如此,蒸馏模型也只产生了206个测试错误,其中133个都是和3有关的。大部分错误都是因为学到的对3的偏置过低。如果这个偏置增加3.5,这个蒸馏模型将会只产生109个错误,其中14个和3有关。所以当偏置正确时,蒸馏模型得到98.6%的在3上面的正确率,即使在训练时从没有见过3。如果转移数据只包括训练数据中的7和8,蒸馏模型会有47.3%的测试错误,当对7和8的偏置降低7.6时,这个错误率会降低到13.2%。

Experiments on speech recognition

在这部分,我们研究用于ASR中的集成DNN声学模型的效果。我们发现我们在论文中提出的蒸馏策略实现了目标效果:把一个集成模型蒸馏成一个比直接从训练数据中学习的同等大小的模型表现更好的简单模型。

先进的ASR系统目前使用了DNN来将一个从波形中提取的特征的时间上下文映射到一个隐式马尔科夫离散状态的概率分布中。更特别地,DNN产生了在三音状态集群上的概率分布,然后解码器在HMM状态上找到一条通路,这个通路是使用高概率状态和在语言模型下产生一个可能的标注的最佳折衷。

尽管可以用这样的方式来训练一个DNN,即在所有可能的路径上边缘化解码器。通过最小化神经网络预测结果和每个观察值的状态的真实序列的强制对齐所给的标签之间的交叉熵来训练DNN实现逐帧分类是一个很常见的操作:
$$ \theta=arg\max_{\theta^‘}P(h_t|s_t;\theta^‘) $$
$\theta$是我们的声学模型$P$的参数,模型$P$将在时间$t$的观察结果$s_t$映射到关于"正确的"HMM状态$h_t$的概率P(h_t|s_t;\theta^‘)$上。这个模型是用分布式随机梯度下降来训练的。

在我们使用的架构中,有8个隐藏层,每层包含2560个线性修正单元,还有一个包含14000个标签(HMM的目标$h_t$的softmax层。输入是26帧的40 Mel-尺度滤波组系数,每帧会往前进10ms,我们要预测第21帧的HMM状态。参数的总量一共85M。这是一个有点过时的声学模型,可以被看作一个baseline。为了训练DNN声学模型,我们使用了大概2000小时时长的英语口语数据,产生了大概700M的训练样例。这个系统实现的帧准确率是58.9%,词错误率是10.9%。

Results

我们训练了10个不同的模型来预测$P(h_t|s_t;\theta)$,使用和baseline一样的架构和训练步骤。这个模型的初始参数值是随机初始化的,我们发现这个导致了训练的模型间的明显的差异性,所以集成模型的预测结果会比单个模型更好。我们也曾通过让每个模型中不同的数据集上训练来增加模型之间的差异性,但我们发现这对我们的结果没什么影响,所以我们还是用了更简单的方法。至于蒸馏,我们尝试了不同的温度[1,2,5,10] 并且对于硬目标的交叉熵我们使用相对权重0.5。

我们的蒸馏方法同简单地使用硬标签来训练一个模型相比,能从训练集中提取更多有用的信息。通过使用10个模型的集成实现的帧分类上的准确率的提升中,百分之八十都被转移为与我们在MNIST上的初步试验中观察到了改进相似的蒸馏模型。由于目标函数的不匹配,集成对WER的最终目标(23K字的测试集)的提升比较小,但集成在WER上的提升也被转移到蒸馏模型中。

我们最近发觉了一些通过匹配训练过的较大模型的类概率来学习小的声学模型的相关工作。然而,他们在温度1上用一个大的么有标签的数据集做蒸馏,他们的最好的蒸馏模型只将小模型的错误率减少了大模型和小模型之间错误率差距的28%。

Training ensembles of specialists on very big datasets

在集成模型上训练是一个利用并行计算的优点的一个非常简单的方法,同时集成在测试时需要大量计算的问题也可以通过蒸馏解决。然而,集成还有一个问题是:如果子模型是大神经网络并且数据集很大,在训练时需要的计算即使很容易并行时间,也会数量过高。

在这一部分,我们给出了数据集的例子,我们展示了如何学习关注不同的可混淆子集的专家模型能减少集成所需要的计算量。关注细粒度的专家模型的主要问题是,他们很容易过拟合,我们会介绍如何通过使用1软目标来防止过拟合。

The JFT dataset

谷歌对JFT数据集的baseline模型是一个在异步梯度下降方法上训练了六个月的深度卷积神经网络。训练使用了两种并行。第一种,这里有很多的神经网络的复制在不同的核上运行,处理训练集里的不同批次。每个复制在它当前的批次上计算平均梯度,并且把这个梯度送到标准的参数服务器,这个服务器会送回参数的新的值。这些新的值反映了参数服务器从它将参数送到复制后接收到的所有的新的值。第二种,每个复制通过将神经元的不同子集放在每个核上从而在在多个核上展开。集成训练是第三种能被其余两种环绕式处理的并行方式,但必须在有更多的核的情况下。等待好几年来训练一个集成模型并不是一个选择,所以我们需要一个更快的提升baseline的方法。

Specialist Models

当类别的数量非常大的时候,老师模型应该是一个包含了在所有数据上训练的通用模型和多个专家模型的集合,每个专家模型都是在一些来自基于很容易混淆的子集上的数据进行训练的(比如不同类型的蘑菇)。这类专家模型的softmax可以通过将所有它不关心的类别组合成一个单一的垃圾类别来做的很小。

为了减少过拟合和共享学习低层特征的解码器的工作,每个专家模型都是用同样模型的权重来初始化的。然后我们通过训练专家模型来修改这些权重,专家模型会在一半来自于它的特有子集上的样例和一半随机从剩余集合中选择的样例上进行训练。在训练后,我们能通过增加垃圾类别的logits乘以专家类抽样比例的对数来修正我们的偏移训练集。

Assigning classes to specialists

为了给专家模型们导出类别的分组,我们决定去关注我们的全网络经常混淆的类别。即使我们能计算出混淆矩阵并且用它来找到这样的集群,我们选择一个更简单的方法,这方法并不需要用真的标签来构造集群。

特别的,我们在我们通用模型的预测结果的协方差矩阵上应用一个集群算法,所以多个经常一起预测的类别$S^m$的集合将作为专家模型$m$的目标。我们应用一个K-means的在线版本来获取合理的集群。我们尝试了几个不同的集群算法,都产生了相似的结果。

Performing inference with ensembles of specialists

在研究在专家模型被蒸馏时会发生什么之前,我们想看看包含专家模型的集成模型表现能有多好。除了专家模型,我们也总会用一个通用模型,所以我们能处理没有专门设置专家模型的类别并且决定去使用哪个专家模型。给定一个输入图像$x$,我们进行"top-one"分类:

  1. 对于每个测试样例,我们根据通用模型确定$n$个最有可能的类别。将类别的集合称为$k$。在我们的实验中,我们使用$n=1$。

  2. 我们接下来使用所有的专家模型$m$,m对应的可混淆类别的子集$S^m$,和$k$有一个非空交集,我们称这个集合为$A_k$。然后我们发现全概率分布所有分类上的$q$,$q$可以最小化:
    $$ KL(p^g,q)+\sum_{m\in A_k}KL(p^m,q) $$
    $KL$是KL散度,$p^mp^g$是一个专家模型或者通用模型的概率分布。分布$p^m$是所有专家类别$m$加上一个垃圾类别的分布。所以从全$q$分布中计算它的KL散度时,我们把所有的概率加起来。

上面的公式没有一个一般的闭式解,即使在所有的模型对每个类别产生一个概率时,这个解不是算术平均就是几何平均。我们将$q$设为$softmax(z)(T=1)$,我们使用梯度下降来最优化logits$z$。注意这个优化的过程需要在所有图像上进行。

Results

从训练好的baseline全网络开始,这个专家模型训练的都非常快。所有的专家模型都是独立训练的。使用61个专家模型,在测试准确率上有4.4%的提升。我们也报告了条件测试准确率,这种准确率只考虑属于专家模型对应的类别的样例,并且把我们的预测结果限制在子集范围内。

对我们的JFT专家模型试验,我们训练了61个专家模型,每个都有300个类别(加上垃圾类别)。因为专家模型对应的类别的集合并不是相脱节的,我们经常会有多个专家模型包括同一个图片类别。当我们有更多的专家模型对应一个特定的类别时,准确度的提升将会更大,因为独立的专家模型很容易并行训练。

Soft Targets as Regularizers

我们使用软目标取代硬目标是因为软目标能携带更多有用但不能被硬目标编码的信息。在这部分我们详细介绍了使用特别少的数据来拟合85M的参数会造成很大的影响。只使用3%的数据,用硬目标训练baseline会导致很严重的overfitting(我们使用了early-stopping,因为在到达44.5%后准确率掉的比较明显)。而同样的模型用软目标在全训练集上进行训练则能够发现更多的信息。我们甚至没有使用early-stopping。使用软目标的系统简单地收敛到57%。这说明软目标是一个很有效的方法,它可以把一个在所有数据上训练过的模型的规律传递给另一个模型。

Using soft targets to prevent specialists from overfitting

我们在JFT数据集上使用的专家模型将所有的非专家模型对应类别压缩到一个垃圾类别中。如果我们允许专家模型在所有类别上进行full softmax,这里或许会有一个比early-stopping更好的防止过拟合的方法。一个专家模型是在他对应的专家类别上进行训练的,这意味着它的训练集的有效大小要小很多并且在自己的类别上很容易过拟合。这个模型不能通过把专家模型变小来解决,否则我们会失去很多有用的我们通过对所有非专家类别建模获得的转移效果。

我们的实验使用了3%的语音数据,实验显示出如果一个专家模型是用通用模型的权重初始化的,我们能让他保留几乎所有的通过用软目标进行训练得到的关于非专家类别的知识。软目标是通用模型提供的。我们目前还在研究这个方法。

Relationship to Mixtures of Experts

在数据子集上训练的专家模型的使用和使用一个门网络来计算将每个样例发给每个专家的概率的混合专家模型很相似。同时,这些专家也在学习处理分配给他们的样例,门网络则学习根据专家的相对判别性能将每个样例分配给哪个专家。使用专家的相对判别性能来决定学到的分配比直接简单的对输入向量进行聚类然后把专家分配给每个集群要好很多,但是它会让并行训练变得更困难:首先,对应每个专家的带权重的训练集按一个由其余所有专家决定的方法进行变化。其次门网络需要在同一个样例上比较不同专家的表现,以了解如何修改它的分配概率。这些困难意味着在可能最有益的体制下混合专家模型很少被使用:包含明显的不同子集的巨大数据集的任务。

很容易对多个专家的训练进行并行。我们首先训练一个通用模型,然后使用混淆矩阵来决定专家在哪个子集上进行训练。等这些子集被定义后,专家模型就可以完全独立地进行训练。在测试时间我们使用通用模型等预测来决定哪个专家模型是相关的,只有这些专家模型需要被使用。

Discussion

我们发现蒸馏方法在从一个集成模型或者一个大的高度规范化的模型中转移知识到一个更小的模型时效果非常好。在MNIST上,蒸馏的效果在转移数据集被用来训练蒸馏模型时效果尤其明显。对于android语音搜索所使用的一种深度声学模型,我们已经表明,通过训练集成深度神经网络模型所取得到提升都可以被蒸馏到一个同样大小的神经网络模型,这样更易于部署。

对于大神经网络,即使训练一个完整的集成也是不可行的,但我们也表明了一个非常大的训练了很长时间的网络也能通过学习多个专家网络来进行提升,每个专家网络都在一个高度易混淆的集群中进行学习。我们还没有成功将专家模型里面的知识蒸馏到一个大模型中。

相关文章
|
8月前
|
机器学习/深度学习 存储 算法
【轻量化网络】概述网络进行轻量化处理中的:剪枝、蒸馏、量化
【轻量化网络】概述网络进行轻量化处理中的:剪枝、蒸馏、量化
283 0
|
机器学习/深度学习 存储 算法
神经网络中的量化与蒸馏
本文将深入研究深度学习中精简模型的技术:量化和蒸馏
113 0
|
8月前
|
机器学习/深度学习 存储 计算机视觉
【论文速递】TPAMI2022 - 自蒸馏:迈向高效紧凑的神经网络
【论文速递】TPAMI2022 - 自蒸馏:迈向高效紧凑的神经网络
|
5月前
|
机器学习/深度学习 人工智能 数据挖掘
通义语音AI技术问题之自蒸馏原型网络的构成如何解决
通义语音AI技术问题之自蒸馏原型网络的构成如何解决
47 0
|
8月前
|
机器学习/深度学习 Shell 计算机视觉
【论文速递】CCDC2021 - 轻量级网络的结构化注意知识蒸馏
【论文速递】CCDC2021 - 轻量级网络的结构化注意知识蒸馏
|
机器学习/深度学习 存储 人工智能
CVPR 2022 | 这个自蒸馏新框架新SOTA,降低了训练成本,无需修改网络
CVPR 2022 | 这个自蒸馏新框架新SOTA,降低了训练成本,无需修改网络
199 0
|
机器学习/深度学习 Shell 计算机视觉
【论文速递】CCDC2021 - 轻量级网络的结构化注意知识蒸馏
【论文速递】CCDC2021 - 轻量级网络的结构化注意知识蒸馏
197 0
|
机器学习/深度学习 存储 计算机视觉
【论文速递】TPAMI2022 - 自蒸馏:迈向高效紧凑的神经网络
【论文速递】TPAMI2022 - 自蒸馏:迈向高效紧凑的神经网络
897 0
|
机器学习/深度学习 算法
【32】多教师网络进行联合蒸馏测试
【32】多教师网络进行联合蒸馏测试
200 0

热门文章

最新文章