离线蒸馏奢侈?在线蒸馏难?都不再是问题DKEL统统帮你解决,蒸馏同质化问题也解决!

简介: 离线蒸馏奢侈?在线蒸馏难?都不再是问题DKEL统统帮你解决,蒸馏同质化问题也解决!

离线知识蒸馏是一种需要昂贵资源训练教师网络,然后将知识蒸馏到学生网络进行部署的两阶段 Pipeline 。另一方面,在线知识蒸馏是一种一阶段策略,通过互相学习和合作学习来缓解这种需求。最近, peer协同学习(PCL)将在线集成,基础网络(学生)和时间平均教师(教师)的协作学习集成到有效知识构建中。然而,由于学生和教师之间的同质化程度高,PCL中的模型偶尔会崩溃。

在本文中,分析了高同质化的原因,并提出了解决方案。通过独立教师生成一个与学生网络分离的解耦知识。这种设计可以增加网络之间的多样性,并减少模型崩溃的可能性。为了获得早期解耦知识,为教师设计了一个初始化方案,并在理想条件下进行2D几何分析实验,以展示该方案的有效性。

此外,为了提高教师的监督弹性,设计了一个衰减的集成方案。它将教师的知识集成到一个动态权重较大,随着训练过程逐渐减小的知识中。集成的知识在早期训练中作为强大的教师,而减少权重的集成知识可以在潜在过拟合的教师监督下消除分布偏差。进行了蒙特卡罗模拟以评估收敛性。

在CIFAR-10,CIFAR-100和TinyImageNet上进行了大量实验,展示了作者方法的优势。消融研究和进一步分析证明了其有效性。

代码:https://github.com/shaoeric/Decoupled-Knowledge-with-Ensemble-Learning-for-Online-Distillation

1 Introduction

近年来,深度学习逐渐在计算机视觉领域占据主导地位。计算机视觉的主流任务,如图像分类和目标检测,在深度学习技术的帮助下也取得了惊人的成功。庞大的网络由于其强大的表示能力,在上述任务中往往能获得更好的特征提取性能。然而,考虑到系统的实时性能和用户体验,这样的庞大网络很难满足产品的要求。

为了解决这个问题,模型压缩技术受到了研究行人的广泛关注。流行的模型压缩技术包括模型剪枝,模型量化,知识蒸馏和轻量级模块设计。

知识蒸馏(KD)提供了一种有效的学习方法,通过模仿来自更优性能的教师网络的输出,来提高紧凑的学生网络。Hinton等人引入了蒸馏温度,以软化教师网络的类别概率分布,从而在学生网络上进行更有效的监督。基于这个想法,已经开发出各种各样的知识来优化学生网络,如提示表示,基于注意的特征图,关系信息,对比表示等。

上述知识和方法遵循两阶段训练 Pipeline ,并要求一个性能更好的教师网络,通常称为传统离线知识蒸馏。然而,对于资源受限的训练任务,可能既没有现成的教师网络,也没有足够的时间和设备来训练一个教师网络。

为了缓解这个问题,开发了在线知识蒸馏,以在单阶段端到端训练过程中同时优化多个紧凑的网络。深度互学习(DML)[32](图1(a))表明,模型可以从彼此学习,而不是盲目地跟随盲目,采用这种在线方式学习的模型比单独学习的模型获得更好的性能。这一结论激发了协作学习的大量研究。

集成学习,在现有的在线知识蒸馏中广泛使用,如KDCL(图1(b)),通过减少每个个体网络的方差,构建一个强大的伪教师网络,从而减轻在线知识蒸馏早期的不准确监督问题。一些最近的工作趋势,将网络的某个分支命名为伙伴[3; 27]。为了构建更强大和稳定的监督,同行合作学习(PCL)引入了时间平均教师(教师)来提高基础网络(学生)(图1(c)),基本上是一种应用网络参数的指数移动平均(EMA)的技术,该技术在批量归一化和YOLO中也有应用。

在这项工作中,以同行合作学习(PCL)作为基准方法,作者提出了解耦知识与集成学习(DKEL),如图1(d)所示。作者提出的DKEL方法包括两种形式的迁移知识,即解耦知识和衰减集成知识。解耦知识被引入来解决PCL中的模型崩溃问题,并详细分析了潜在的原因。

作者通过扩展教师的解决方案空间来提出解决方案,这涉及构建一个独立的教师并设计一个教师初始化方案来生成早期的解耦知识。作者进行了一个理想的实验来验证这种方法,并将结果以2D几何格式呈现。

此外,为了促进强大的有效监督,作者设计了一个衰减的集成方案,其中教师同伴的logits被组装并分配一个衰减的权重用于学生监督。由教师生成的解耦和衰减集成知识增强了学生监督信息的有效性,同时教师的性能从EMA更新中受益。

工作贡献如下:

  1. 在线蒸馏方法提出了解耦知识与集成学习的解耦知识,并展示了其在现有方法上的理论及实证优势。
  2. 为了应对教师和学生的同学之间的高同质化导致模型崩溃问题,构建了一个具有解耦知识的教师网络。
  3. 设计了一个衰减集成知识的方案,作为早期的稳健教师,加速优化并避免在晚期训练中提供过拟合监督。
  4. 所提出的方法在多个数据集和不同架构上进行了评估。与SOTAs进行了性能比较。

2 Related work

知识蒸馏通常用于模型压缩,这被分为离线和在线知识蒸馏:

离线知识蒸馏需要一个预训练的教师网络和一个学生网络,学生网络同时学习教师网络和真实值。Hinton等人提出了知识蒸馏的概念,并提出了通过将教师网络和学生网络的软分布对齐来将知识从复杂的教师网络转移到紧凑的学生网络的方法。通常,离线知识蒸馏倾向于训练一个具有惊人性能的学生网络,对于压缩大型网络非常有效,然而,它通常遵循两阶段范式,这大大增加了训练时间和计算开销。

在线知识蒸馏是一种端到端的一阶段训练方法,它不需要额外的训练时间和计算资源来预训练一个庞大的教师网络。张等人开创了一种深度互学习方法,探讨了在线知识蒸馏的可行性,该方法在多个具有相同输入的并行模型之间相互转移知识。为了保持多个网络的多样性,郭等人随机为每个个体网络的相同输入进行数据增强,并将所有输出logits汇总为一个集成软标签来优化每个网络。吴等人使用堆叠策略并利用时间平均教师来推导出训练和推理的强健预测。

3 Methodology

Formulation

给定一个包含类别的目标数据集,其中包含个训练样本,其中是第个图像,是对应的 GT 值,且。将样本输入到教师和学生网络中,对应的推导输出logits分别表示为和。softmax函数将logits规范化为概率向量,通常使用Kullback-Leibler(KL)散度来最小化教师和学生之间的概率分布差距。通常,需要通过蒸馏温度超参数对概率分布进行平滑。

具体而言,知识蒸馏损失函数可以表示为:

KL函数的值只有在时才会等于0。

Pcl

本节简要描述了PCL方法。将两种类型的模型集成到一个统一框架中:一个基础模型和一个时间平均教师,其中基础模型被视为学生,时间平均教师简称为教师。

令表示学生的第个同行,表示相应的教师同行。如图1(c)所示,PCL为第个同行将输入增加到,并推导出相应的扁平特征和logits。特征通过堆叠组装成同行集成,就像一样,监督每个logits。 GT 值使用交叉熵损失来避免盲目引导盲目,并使用来改进集成的容量。

此外,使用具有更好泛化能力的教师同行通过进行知识传递。

假设表示同行数量,PCL对于学生的损失函数可以表示为:

其中,表示交叉熵损失,表示对同行集成的交叉熵损失,表示正则化损失。

其中,是一个集成函数。

在学生每次迭代后,使用EMA更新教师参数。

在这里,是一个超参数,和分别是优化的学生和教师。

Decoupled knowledge

解耦知识关注PCL中潜在的模型崩溃问题。如图2所示,PCL引发某些可能导致模型训练崩溃的情况。检查崩溃网络的参数矩阵,发现为了使教师和学生的知识分布尽可能接近,崩溃网络中的所有参数都被迫设置为零,logits都是零向量。形式化地表示为:对于,满足,其中网络和都作为具有全部零参数的线性变换。

崩溃的 essential 原因是教师和学生的同行(简写为和)同质化程度过高,导致蒸馏损失过小。此外,在和同质化的情况下,logits 值由于正反馈而变小,导致交叉熵损失过小。

总结而言,和同质化将导致  过小。此时,SGD 优化器中的 L2 正则化项被优先优化,导致所有参数的绝对值越来越小,直到网络崩溃。

和的同质化主要是由于它们共享相同的 Backbone 网络,而相对较小的独立同行,由于其相对较小的解空间,可以轻松触发上述崩溃条件。

为了减少同质化,设计了一个简单但有效的解耦策略,其中构建了一个独立的教师来将解耦知识传递给学生,教师和学生之间没有共享参数。

如图3所示,增强被输入到学生的 Backbone 并依次传递给第个同行,得到的logits。对于教师,除了第个增强以外的所有增强都被输入到教师 Backbone 及其个同行中,得到相应的logits集合,因此解耦知识的损失函数可以表示为:

在构建解耦知识后,必须考虑如何初始化教师网络。在当前的神经网络中,随机初始化被广泛使用,但是它可以导致早期优化方向的偏差,这在本文中被称为早期知识偏差。

为了说明这一点,如图4所示,作者设计了一个理想的实验,并使用2D几何风格展示了结果。数据和网络的分布被抽象地表示为点,两个点之间的距离理想化地表示它们的分布差距。和分别表示真实数据分布和 GT 分布,表示学生第个同行的分布,表示教师第个同行的分布。是两个学生同行的堆叠集成,位于和的中点。

为了使图形的表示更清晰,只展示了和的优化。在三个监督下优化:优化到,使用;优化到,使用;优化到,使用。

根据向量求和法则,优化的同行分布可以表示为:

然后,EMA优化到,这可以表示为:

由于初始化的随机性,有可能导致和位于的相反一侧,即也是一个明显的钝角,以及。这将导致优化的偏移,从而使优化的受到偏置的影响。

因此,提出了一种教师网络的初始化方案,该方案包括两个步骤:

  1. 将学生的权重复制到教师中,以确保两个网络具有相同的初始分布。
  2. 使用交叉熵在仅几步内优化教师网络,学习率较小。

如图4(b)所示,这种初始化将调整到与在同侧的位置,作为参考。与使用随机初始化的相比,它加速了向方向靠近,从而提高了性能。得益于这一点,优化的和比随机初始化更接近,从而提高了性能。

此外,从统计的角度来看,使用EMA方法训练的相当于参数的第一阶矩估计以及Adam优化器中的动量[28],如果的初始性能可以得到改善,那么在训练过程中的全局性能也会得到改善。

Decaying ensemble strategy

集成学习可以用于增强多个网络的泛化性能并得到更好的评估。在在线知识蒸馏领域,它通常用于构建具有多个logits的强健的监督信息。PCL中的教师和提出的解耦知识直接监督学生的训练,而不涉及logits集成,这可能导致由于早期教师的能力有限而学生的性能受到限制。

随着网络的持续优化,教师的同行会逐渐适应或甚至过拟合到。尽管其集成通常产生最佳泛化知识,但仍然存在向分布的偏差,这可能为学生提供过拟合监督。

基于上述说法,进一步提出了一种衰减集成策略来提高解耦知识。对于第个学生同行,的损失函数和教师同行集成的损失函数可以表示为:

其中,表示第个教师同行,表示的交叉熵损失,表示的交叉熵损失,表示到的损失,表示到的损失。

其中,的目的是避免高同质化和潜在的模型崩溃。

为了减少过度拟合的教师网络对分布的偏差,采用指数权重函数对进行衰减,随着训练轮次的增加而衰减,这可以表示为:

其中, 是一个正的衰减超参数。因此,的 DKEL 损失函数可以表示为:

为验证上述基于理想实验的假设,设计了使用三个网络同行进行蒙特卡罗模拟。假设分布被表示为点,两个点之间的距离表示对应分布的损失,优化方向是从点到目标点的向量方向,集成表示为两个点的交点。

在这里,学习率是0.1,只有一个步骤用于构建解耦知识,而EMA的值为0.5。进行了90个epoch的实验,每个epoch包含10,000次尝试,以便比较和之间的分布差距,并报告每个epoch的平均差距。

如图5所示,蓝色点线表示提出的解耦知识,简称为DK在图中,绿色点线和橙色点线分别表示PCL和DKEL。结果表明,DK仅稍微优于PCL,两种方法在约第70个epoch时收敛。然而,DKEL利用解耦知识在第一10个epoch中加速优化过程,通过创建强大的监督。这使得网络收敛更快,在第60个epoch时收敛。此外,DKEL中的衰减集成方案产生的差距比其他两种方法小。

根据上述描述,算法1描述了优化过程。优化后,使用教师进行推理和部署,由于教师和学生具有相同的模型结构,因此不会产生额外的负载。

4 Experiments

本文讨论了实验中使用的数据集,实现细节,实验设置,并使用SOTAs进行比较,以评估所提出方法的优势。作者还对所提出的进行了方法的有效性和鲁棒性的消融分析。

Datasets

为评估作者提出的算法,作者在几个数据集上进行了实验,包括CIFAR-10,CIFAR-100 [13]和TinyImageNet [4]。CIFAR-10包含50,000个训练图像和10,000个验证图像,来自10个目标类别,其中每个图像是一个 RGB图像。CIFAR-100与CIFAR-10具有相同的图像数量和大小,但来自100个目标类别。TinyImageNet由200个类别的100,000张彩色图像组成,每个类有500个训练图像和50个验证图像。在设计的实验中,报告了验证精度。

Implementation details

所有实验基于pytorch实现。模型由Chen等人提供,包括ResNet,VGG,DenseNet和WideResNet,采用三同行架构,如PCL所示。对于数据增强,在训练中应用随机水平翻转,随机裁剪和归一化,并将图像大小分别调整为32x32(对于CIFAR-10/100)和64x64(对于TinyImageNet)。

使用SGD with Nesterov momentum优化网络,其中momentum为0.9,权重衰减为。批大小设置为128,设置为0.5。解耦知识需要一次优化,学习率设置为0.01,在CIFAR-10/100上训练300个周期,在TinyImageNet上训练100个周期。初始学习率设置为0.1,在CIFAR-10/100上的周期和TinyImageNet上的周期衰减为0.1。

Comparison with SOTAs

在本节中,作者将方法DKEL在CIFAR-10/100和TinyImageNet数据集上进行评估,并与先前的在线知识蒸馏工作进行比较,包括DML,CL,ONE,OKDDip,KDCL和PCL。

为了进行公平的比较,参考[14; 27],在比较的方法中应用了三分支架构,包括ONE,CL,OKDDip,PCL和DKEL,并且为DML和KDCL的需求应用了三个并行网络。对于PCL,作者根据论文进行了复制,因为不存在开源代码,并报告了作者复制的结果。

4.3.1 Results on CIFAR-10/100

如表1所示,DKEL相对于PCL在CIFAR-10和CIFAR-100数据集上平均提高了约0.3%,证明了所提出方法的有效性。例如,在CIFAR-10上,DKEL相对于PCL在VGG-16和DenseNet-40-12上分别提高了0.61%和0.52%,在ResNet-32和ResNet-100上分别提高了0.18%和0.16%。

在CIFAR-100上,DKEL相对于PCL在DenseNet-40-12和ResNet-110上分别提高了0.42%和0.4%,在ResNet-32和VGG-16上分别提高了0.37%和0.24%。在CIFAR-10和CIFAR-100上,相对于PCL,DKEL与WRN-20-8的性能差异很小,这可能是因为接近性能极限的网络很难从集成解耦知识中实现更大的改进。

4.3.2 Results on TinyImageNet

表2显示了在TinyImageNet上的准确性能。DKEL在ResNet-18和ResNet-34上分别将 Baseline 提高了约13%。与PCL相比,DKEL在ResNet-18和ResNet-34上分别获得了0.1%和0.04%的改进。

DKEL对ResNet-34的微小改进可能是因为数据集和网络都更加复杂,构建解耦知识的迭代次数太少,无法在训练的早期阶段提供显著有效的监督。

4.3.3 Ablation study

在CIFAR-10/100数据集上使用ResNet-32进行了对提出的两个损失函数的消融研究。

表3报告了结果,并显示了所提出方法的有效性。例如,解耦知识通过在CIFAR-100上提高了 Baseline 0.53%,而衰减集成策略提高了 Baseline 0.5%,两者设计均提高了 Baseline 0.69%。

不同迭代次数对解耦知识的性能有影响。图6说明了ResNet32和DenseNet-40-12在CIFAR-10/100上的性能与迭代次数之间的关系。左列和右列分别报告了CIFAR-10和CIFAR-100上的结果,并且小麦条表示ResNet-32,鲑鱼条表示DenseNet-40-12。

在四个设置中,三个呈现了积极趋势,即解耦知识迭代次数越多,网络性能越好。以上结果基本上符合作者的预期:

  1. 当更多的初始解耦知识可以表示数据集的真实分布时,原网络学习的知识应该更有效。
  2. 可以实现良好的蒸馏性能,从而大大降低训练开销。

具体而言,1-迭代和10-迭代ResNet-32在CIFAR-10上分别取得了94.56%和94.63%的性能,1-迭代和5-迭代ResNet-32在CIFAR-100上分别取得了74.83%和74.96%的准确率。

对于DenseNet-40-12在CIFAR-10上的情况,模型参数的数量比CIFAR-100和CIFAR-10上的ResNet-32少,因此更可能出现局部最优,可以在更多迭代次数后改进。

对超参数的敏感性分析研究。VGG-16和DenseNet-40-12在CIFAR-10/100上进行训练,采用不同的。

如表4所示,VGG-16在CIFAR-10上的性能范围为94.25至94.89,在CIFAR-100上的性能范围为76.73至77.20,而DenseNet-40-12在CIFAR-10上的性能范围为93.64至94.23,在CIFAR-100上的性能范围为72.75至73.36,这意味着性能范围约为0.6%。

此外,当时,网络可能得到较差的结果,因为过大导致的权重过快衰减并收敛为0,仅在两轮优化后。当较小时,的权重收敛为0需要数十个训练周期,并在训练的后期阶段不影响优化目标,表4也显示较小的值对网络性能有益。

探索了几种组合知识对训练性能的影响。组合知识的权重应该随着训练周期的增加而减小,因此进行了余弦衰减和线性衰减方案,并与设计的指数衰减方案进行比较。和分别表示最大训练周期和当前周期,因此余弦衰减方案表示为:

余弦衰减方案表示为:

图7显示,设计的指数衰减方案在几乎所有周期内都取得了最佳性能,而余弦衰减方案则表现最差。在训练稳定性方面,指数衰减方案也优于其他两种,并且可以注意到余弦衰减方案在后期阶段遇到了严重的过拟合,导致性能显著下降。

5 Conclusion

作者提出了一种名为DKEL的在线知识蒸馏方法,该方法使用由时间平均教师生成的解耦知识来避免模型崩溃。作者还设计了一种产生早期解耦知识和一个衰减集成策略来增强早期监督的鲁棒性并减少晚期知识的偏差。理想分析和蒙特卡罗模拟证明了该方法的设计动机和机制,而实验则确认了该方法的有效性和优越性。

参考

[1].Decoupled Knowledge with Ensemble Learning for Online Distillation

相关文章
|
2月前
|
机器学习/深度学习 前端开发 PyTorch
【轻量化:蒸馏】都2023年了,你还不会蒸馏操作,难怪你面试不通过!
【轻量化:蒸馏】都2023年了,你还不会蒸馏操作,难怪你面试不通过!
30 0
【轻量化:蒸馏】都2023年了,你还不会蒸馏操作,难怪你面试不通过!
|
2月前
|
机器学习/深度学习 编解码 PyTorch
训练Sora模型,你可能需要这些(开源代码,模型,数据集及算力评估)
在之前的文章《复刻Sora有多难?一张图带你读懂Sora的技术路径》,《一文看Sora技术推演》我们总结了Sora模型上用到的一些核心技术和论文,今天这篇文章我们将整理和总结现有的一些开源代码、模型、数据集,以及初步训练的算力评估,希望可以帮助到国内的创业公司和个人开发者展开更深的研究。
|
8月前
|
人工智能 缓存 并行计算
终极「揭秘」:GPT-4模型架构、训练成本、数据集信息都被扒出来了
终极「揭秘」:GPT-4模型架构、训练成本、数据集信息都被扒出来了
345 0
|
11月前
|
存储 PyTorch TensorFlow
恕我直言,你们的模型训练都还不够快
恕我直言,你们的模型训练都还不够快
|
11月前
|
存储 机器学习/深度学习 人工智能
调教LLaMA类模型没那么难,LoRA将模型微调缩减到几小时
调教LLaMA类模型没那么难,LoRA将模型微调缩减到几小时
385 0
|
11月前
|
机器学习/深度学习 自动驾驶 计算机视觉
目标检测提升技巧 | 结构化蒸馏一行代码让目标检测轻松无痛涨点(一)
目标检测提升技巧 | 结构化蒸馏一行代码让目标检测轻松无痛涨点(一)
106 0
|
11月前
|
计算机视觉
目标检测提升技巧 | 结构化蒸馏一行代码让目标检测轻松无痛涨点(二)
目标检测提升技巧 | 结构化蒸馏一行代码让目标检测轻松无痛涨点(二)
116 0
|
11月前
|
机器学习/深度学习 自动驾驶 计算机视觉
目标检测提升技巧 | 结构化蒸馏一行代码让目标检测轻松无痛涨点
目标检测提升技巧 | 结构化蒸馏一行代码让目标检测轻松无痛涨点
175 0
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
大模型训练之难,难于上青天?预训练易用、效率超群的「李白」模型库来了!(2)
大模型训练之难,难于上青天?预训练易用、效率超群的「李白」模型库来了!
135 0
|
11月前
|
机器学习/深度学习 人工智能 并行计算
大模型训练之难,难于上青天?预训练易用、效率超群的「李白」模型库来了!(1)
大模型训练之难,难于上青天?预训练易用、效率超群的「李白」模型库来了!
171 0