这篇论文提出了一种名为CrossKD的简单而有效的知识蒸馏方案,用于对目标检测模型进行模型压缩。现有的目标检测领域的最先进知识蒸馏方法主要基于特征蒸馏,一般被认为比预测蒸馏更好。然而,作者发现预测蒸馏的低效性主要是由于GT信号和蒸馏目标之间的优化目标不一致造成的。
为了解决这个问题,作者提出了CrossKD这种简单而有效的蒸馏方案,该方案将学生模型的检测Head的中间特征传递给教师模型的检测Head,并强制使得交叉Head的预测结果与教师模型的预测一致。通过这种蒸馏方式,学生模型的Head能够避免接收来自GT注释和教师模型预测的相互矛盾的监督信号,从而大大提高了学生模型的检测性能。
在MS COCO数据集上,仅应用预测蒸馏损失,作者的CrossKD将GFL ResNet-50模型的平均精度从40.2提升到43.7,超过了所有现有的目标检测领域的知识蒸馏方法。
1、简介
知识蒸馏(KD)作为一种模型压缩技术,在目标检测领域进行了深入研究,并且近年来取得了卓越的性能。根据蒸馏位置的不同,现有的知识蒸馏方法可以大致分为两类:
- 预测蒸馏
- 特征蒸馏
预测蒸馏(见图1(a))最早在[Distilling the knowledge in a neural network]中提出,指出教师的预测结果的平滑分布对学生的学习更有利,而不是GT值的Dirac分布。换句话说,预测蒸馏迫使学生的预测结果与教师的预测分布相似。
而特征蒸馏(见图1(b))则遵循了FitNet中提出的思想,认为中间特征包含的信息比教师的预测结果更多。它旨在强迫教师和学生之间的特征保持一致性。
预测蒸馏在蒸馏目标检测模型中起着至关重要的作用。然而,长期以来观察到预测蒸馏比特征蒸馏更低效。通常情况下,郑等人提出了一种名为局部化蒸馏(LD)的方法,通过传递定位知识来改进预测蒸馏,将预测蒸馏推向了一个新的水平。尽管与先进的特征蒸馏方法(例如PKD)相比仍有差距,但LD表明预测蒸馏具有传递任务特定知识的能力,这使得学生模型能够从与特征蒸馏不同的角度受益。这激发了作者进一步探索和改进预测蒸馏的动力。
通过观察,作者发现预测蒸馏需要应对GT目标和蒸馏目标之间的冲突,而这在先前的工作中被忽视了。当使用预测蒸馏来训练一个检测器时,学生模型的预测结果被同时迫使蒸馏GT目标和教师模型的预测结果。然而,教师模型预测的蒸馏目标通常与分配给学生模型的GT目标存在很大的差异。
如图2所示,教师模型在绿色圈出的区域中产生了不准确的类别概率,这与GT目标产生了冲突。因此,在蒸馏过程中,学生模型经历了一种矛盾的学习过程,作者认为这是阻碍预测蒸馏实现更高性能的主要原因。
在本文中,作者提出了一种新颖的交叉Head知识蒸馏方法,称为CrossKD,以缓解目标冲突问题。如图1(c)所示,作者建议将学生模型Head的中间特征输入到教师模型的Head,得到交叉Head预测。然后,可以在新的交叉Head预测和教师模型的原始预测之间进行知识蒸馏操作。
尽管CrossKD非常简单,但具有以下两个主要优势:
- 首先,知识蒸馏损失不会影响学生模型Head的权重更新,避免了原始检测损失和知识蒸馏损失之间的冲突。
- 此外,由于交叉Head预测和教师模型的预测都是通过共享部分教师模型的检测Head生成的,因此交叉Head预测与教师模型的预测相对一致。这减轻了教师-学生对之间的差异,增强了预测蒸馏的训练稳定性。
这两个优势使得作者的CrossKD能够高效地从教师模型的预测中提取知识,并且比先前最先进的特征蒸馏方法具有更好的性能。
作者的方法没有任何花里胡哨的东西,但可以显著提升学生模型检测器的性能,实现更快的训练收敛。本文在COCO数据集上进行了全面的实验,详细说明了CrossKD的有效性。
具体而言,仅应用预测蒸馏损失的情况下,CrossKD在1×训练计划下在GFL上达到了43.7 AP,比Baseline高出3.5 AP,超过了之前所有最先进的目标检测知识蒸馏方法。此外,实验证明作者的CrossKD与特征蒸馏方法是正交的。通过将CrossKD与最先进的特征蒸馏方法(例如PKD)相结合,作者在GFL上进一步实现了43.9 AP的性能。
2、本文方法
2.1 准备工作
在介绍作者的CrossKD之前,作者首先简要回顾两种主要类型的知识蒸馏方法:特征蒸馏和预测蒸馏。
1、特征蒸馏
特征蒸馏最早是在[FitNets]中提出的,旨在在教师和学生的潜在特征上强制实现一致性。其目标可以表示为以下公式:
在上述公式中,和分别表示学生和教师的中间特征,通常是指特征金字塔网络(Feature Pyramid Network,FPN)的特征。被引入来衡量和之间的距离,例如,在[FitNets]中使用均方误差(Mean Square Error),在[PKD]中使用皮尔逊相关系数(Pearson Correlation Coefficient)。
在上述描述中,表示区域选择原则,它在整个图像区域R中为每个位置生成一个权重。为了避免大幅度的噪声干扰模型的收敛性,不同的方法可能会使用不同的区域选择原则来选择用于蒸馏的有效区域,并平衡前景和背景样本的权重。最后,损失将通过整个中间特征上的的累积进行归一化处理。
特征蒸馏由于其出色的性能已经成为目标检测知识蒸馏方法的主流。然而,这可能强迫学生模型蒸馏教师模型中的不必要的噪声,这可能对最终结果产生负面影响。
2、预测模拟
与特征蒸馏不同,预测蒸馏旨在通过最小化教师和学生之间的预测差异来传递深层知识。其目标可以描述为:
其中,和分别是学生和教师检测Head生成的预测向量。区域选择原则根据不同的工作而有所不同。是计算和之间差异的损失函数,例如,用于分类的KL散度,用于回归的L1损失和LD。
正如在LD中描述的那样,由于预测具有明确的物理含义,预测蒸馏可以向学生提供特定任务的知识。然而,与特征蒸馏方法相比,预测蒸馏的性能较差限制了其应用。
2.2 Cross-Head Knowledge Distillation
正如在第1节中所述,作者观察到直接蒸馏教师的预测会面临目标冲突问题,这妨碍了预测蒸馏方法取得良好的性能。为了缓解这个问题,在本节中作者提出了一种新颖的交叉Head知识蒸馏(CrossKD)方法。总体框架如图3所示。与许多以前的预测蒸馏方法类似,作者的CrossKD在预测上进行蒸馏过程。不同的是,CrossKD将学生的中间特征传递给教师的检测Head,并生成交叉Head预测以进行蒸馏。
对于像RetinaNet这样的检测器,每个检测Head通常由一系列卷积层组成,表示为。为了简化起见,作者假设每个检测Head中总共有个卷积层(例如,在具有4个隐藏层和1个预测层的RetinaNet中是5个)。作者用来表示生成的特征图,表示的输入特征图。预测是由最后一个卷积层生成的。因此,对于给定的教师-学生对,教师和学生的预测可以分别表示为和。
除了教师和学生的原始预测之外,CrossKD还额外将学生的中间特征传递给教师的检测Head的第个卷积层,从而产生交叉Head预测。给定,作者建议使用交叉Head预测与教师的原始预测之间的KD loss来作为作者CrossKD的目标,具体描述如下:
其中,和分别表示区域选择原则和归一化因子。为了避免设计复杂的,作者遵循训练密集检测器的默认操作。在分类分支中,是一个恒定函数,其值为1。在回归分支中,是指示符,在前景区域生成1,在背景区域生成0。根据每个分支的不同任务(例如分类或回归),作者使用不同类型的来有效地将特定于任务的知识传递给学生。
通过进行CrossKD,检测损失和蒸馏损失分别应用于不同的分支。如图3所示,检测损失的梯度通过学生的整个Head传递,而蒸馏损失的梯度通过冻结的教师层传播到学生的潜在特征上,从而在启发式上增加了教师和学生之间的一致性。与直接对教师-学生对之间的预测进行调整相比,CrossKD允许学生的一部分检测Head仅与检测损失相关,从而更好地优化GT目标。作者在实验部分进行了定量分析。
2.3 优化目标
训练的总损失可以表示为检测损失和蒸馏损失的加权和,如下所示:
其中,和表示检测损失,它们是在学生的预测值、与对应的真实目标值、之间计算得出的。额外的CrossKD损失表示为和,它们是在交叉Head的预测值、与教师的预测值、之间进行计算的。
作者在不同的分支中应用不同的距离函数来传递特定任务的信息。在分类分支中,作者将教师预测的分类分数视为软标签,并直接使用GFL中提出的Quality Focal Loss(QFL)来拉近教师和学生之间的距离。
至于回归,密集检测器中主要有两种形式的回归方法:
- 第一种回归形式直接从Anchor框(例如RetinaNet,ATSS)或点(例如FCOS)回归边界框。在这种情况下,作者直接使用GIoU作为Dpred。
- 另一种情况是回归形式预测一个向量来表示框位置的分布(例如GFL),它包含比边界框表示的Dirac分布更丰富的信息。
为了有效地蒸馏位置分布知识,作者采用KL散度(如LD)来进行位置知识的转移。关于损失函数的更多细节请参考补充材料。
3、实验
3.1 Positions to apply CrossKD
3.2 CrossKD v.s. Feature Imitation
3.3 CrossKD v.s. Prediction Mimicking
3.4 CrossKD v.s. HEAD
3.5 CrossKD for Lightweight Detectors
3.6 Comparison with SOTA KD Methods
3.7 CrossKD on Different Detectors
4、参考
[1].CrossKD: Cross-Head Knowledge Distillation for Dense Object Detection.