助力目标检测涨点 | 可以这样把Vision Transformer知识蒸馏到CNN模型之中

简介: 助力目标检测涨点 | 可以这样把Vision Transformer知识蒸馏到CNN模型之中

资源受限的感知系统,例如边缘计算和面向机器人视觉,要求视觉模型在计算和内存使用方面既准确又轻量化。虽然知识蒸馏是增强轻量级分类模型性能的一种已验证策略,但将其应用于目标检测和实例分割等结构化输出仍然是一项复杂的任务,因为输出变化多样且知识蒸馏过程中涉及到的网络模块复杂。

本文提出了一种简单但出奇制胜的渐进式知识蒸馏方法,逐步将一组教师检测器的知识传递给给定的轻量级学生模型。为了从高精度但复杂的教师模型中蒸馏知识,作者构建了一系列教师来帮助学生逐渐适应。作者的渐进式策略可以轻松与现有的检测蒸馏机制结合使用,以在各种情况下始终最大化学生性能。

据作者所知,这是首次成功地将Transformer-based教师检测器的知识蒸馏到基于卷积的学生模型中,并在MS COCO基准测试中将RetinaNet基于ResNet-50的性能从36.5%提高到42.0% AP,将Mask RCNN的性能从38.2%提高到42.5% AP。

1、简介

在安全关键的实时应用中部署深度神经模型具有挑战性,特别是在资源有限的设备上,如自动驾驶汽车或虚拟/增强现实头戴式设备。这主要是由于巨大的计算复杂性和庞大的内存/存储需求。一种有效的策略是通过知识蒸馏来训练轻量级架构,该策略能够将大模型中学到的信息压缩到小模型中。

在目标检测领域实施知识蒸馏,尽管已经有相关工作,但仍然存在其独特的挑战,这些挑战源于复杂的任务输出:检测器使用多任务头(用于分类和框/Mask回归),产生可变长度的输出,这使得检测不同于单一输出分类任务。因此,通常不直接适用于检测的分类蒸馏方法,需要为检测开发专门的方法。

最近的工作主要考虑设计先进的蒸馏损失函数,用于将特征从教师传递给学生。然而,存在2个未解决的挑战:

  1. 模型之间的容量差距可能导致次优的蒸馏学生,即使使用了最强大的教师模型也是如此,这在优化学生模型的精确性和效率权衡时是不希望的。此外,当从基于Transformer的教师蒸馏知识到经典的基于卷积的学生时,这种架构差异可能会扩大师生差距(图1)。
  2. 目前的方法假定已选择了一个目标教师。然而,在现有的检测蒸馏文献中,忽略了教师模型选择的这种meta-level别优化。事实上,找到一组强大的教师候选人很容易,但在确定特定学生的最兼容教师之前,可能需要进行反复尝试。

为了解决这些挑战,作者提出了一个通过多教师渐进蒸馏(MTPD)来学习轻量级检测器的框架:

  1. 作者发现从多个教师进行顺序蒸馏明显改善了知识蒸馏,并弥合了不同架构引起的师生容量差距。如图1所示,即使存在巨大的架构差异,MTPD可以有效地将知识从基于Transformer的教师传递到基于卷积的学生中,而以前的方法则不能。
  2. 对于教师选择问题,作者为给定的学生和一组教师候选人设计了一种启发式算法,以自动确定在蒸馏过程中使用的教师顺序。该算法基于模型之间的表示相似性分析,不需要对特定蒸馏机制的先验知识。

总之,作者的主要贡献包括:

  • 提出了通过多教师渐进蒸馏(MTPD)来学习轻量级检测器的框架,这是一种简单但有效且通用的方法。作者开发了一种有原则的方法,可自动设计适用于给定学生的一系列适当的教师并逐渐蒸馏它。
  • MTPD是一种meta-level策略,可以轻松与以前在检测蒸馏中的努力相结合。作者在具有挑战性的MS COCO数据集上进行了全面的实证评估,观察到一致的增益,而不管蒸馏损失的复杂性如何。
  • MTPD学习了具有最先进精度的轻量级RetinaNet和Mask R-CNN,即使在异构的 Backbone 和输入分辨率设置中也是如此。也许最令人印象深刻的是,作者首次调查了从基于Transformer的教师检测器到卷积基学生的异构蒸馏,并发现渐进蒸馏是弥合它们差距的关键。
  • 作者经验证明,改进来自更好的泛化而不是更好的优化。从多个教师传递的知识使学生走向更平坦的最小值,从而帮助学生更好地泛化。

2、本文方法

在多教师渐进蒸馏(MTPD)中,作者提出逐渐将学生模型S与一组N个教师P = {Ti} N i=1 蒸馏。典型的目标检测器由四个模块组成:

  1. Backbone 网络,用于提取视觉特征,如ResNet和ResNeXt
  2. Neck 网络,用于从 Backbone 网络的不同阶段提取多级特征图,如FPN和Bi-FPN
  3. 用于两阶段检测器的可选区域建议网络(RPN)
  4. Head 网络,用于生成目标检测和分割的最终预测

作者将 Neck 网络的输出特征图表示为F Net,其中Net可以是学生模型S或教师集合P中的一个教师Ti ∈ P。具有FPN等 Neck 网络的特征图可以是多级的。MTPD是一种通用的目标检测蒸馏元策略,通过一系列教师逐渐学习学生。在这里,为了在不涉及复杂的蒸馏机制的情况下检验这一元策略,作者在第3.1节中引入了单个教师Ti的简单特征匹配蒸馏。然后作者在第3.2节中讨论了如何使用来自P的多个教师进行渐进蒸馏。

3.1. 初步:通过简单的特征匹配进行单教师蒸馏

为了通过蒸馏学习高效的学生检测器S,作者鼓励学生的特征表示与教师的特征表示相似。为此,作者最小化了教师和学生的特征表示之间的差异。简化起见,作者只是最小化了 和 之间的L2距离:

其中是用于匹配教师和学生特征图维度的函数。

作者将定义如下:

  • (同质情况)如果和S之间的通道数和空间分辨率都相同,则是一个恒等函数。
  • (异质情况)如果通道数不同但空间分辨率相同,则作者使用1×1的卷积作为。如果空间分辨率不同但通道数相同,则作者使用上采样层作为。如果通道数和空间分辨率都不同,则作者将卷积和上采样层组合为。请注意,映射只在训练时需要,因此不会增加推理的开销。

总的来说,作者的损失函数可以写成:

其中λ是平衡参数,Ldetect是基于GT标签的检测损失。与最先进的检测蒸馏方法相比,这种特征匹配蒸馏更简单,不需要运行教师模型的 Head 。作者的蒸馏损失如图2-Left所示。

3.2. 多教师的渐进蒸馏

知识蒸馏的总体目标是使学生模仿教师的输出,以便学生能够获得与教师相似的性能,然而,学生模型的容量有限,这使得学生很难从高度复杂的教师那里学到东西。

为了解决这个问题,使用多个教师网络为学生提供更多的监督。与以前的方法不同,以同时蒸馏所有特征信息的方式,作者建议逐步从多个教师中蒸馏基于特征的知识。作者的关键见解是,学生可以通过每次一个教师提供的知识更有效地蒸馏,而不是模仿所有特征信息的集合。这种渐进知识蒸馏方法可以被认为是由一系列教师提供的课程,如图2-Right所示。

关键问题是:在蒸馏学生时,什么是教师的最佳顺序O?

一种蛮力方法是搜索所有顺序并选择最佳的顺序(产生具有最高验证准确性的蒸馏学生)。然而,随着教师数量的增加,排列顺序的空间呈指数增长,这使得这种方法不切实际。因此,作者提出了一种基于每个模型学到的特征表示的相关性分析的原则性和高效的方法。

首先,作者量化了每对模型表示之间的不相似性,作为其容量差距的代理。表示(不)相似性已经被研究用于理解神经模型的学习能力。在作者的设置中,作者发现线性回归模型足以测量表示的不相似性。给定两个训练好的检测器A和B,作者冻结它们的参数,从而固定了特征表示。然后,作者学习一个线性映射,由每个特征级别的1×1卷积层实现,如第3.1节的异构情况中所指定的。被训练以最小化,因此它可以将A的特征转换为B的特征的近似值。

在训练完后,作者通过验证集上的来评估它,并将验证损失称为适应成本。这个指标可以作为两个模型之间容量差距的代理:当为零时,线性映射可以将A的特征转换为B的特征,并且没有来自B的额外知识。当较大时,将A的表示适应到B的表示更加困难。

请注意,适应成本是非对称的 - 将高容量模型的表示适应到低容量模型的表示相对较容易,反之亦然。

作者设计了一个启发式算法,称为反向贪心选择(BGS),以自动获取接近最优的蒸馏顺序O(请参见算法1中的伪代码和图3中的示例)。

假设要选择的最大教师数量由k限制(可以根据所需的训练时间进行任意决定),作者的目标是找到一个长度不超过k的教师索引序列α。作者按照相反的顺序构建教师顺序:最佳性能的教师被设置为最终目标α;在最终教师之前,作者使用另一个教师,该教师与最终教师的适应成本α最小,作为倒数第2个教师α。作者重复此过程以查找前面的教师,直到:

  1. 在尝试选择α时,作者发现从剩余教师到下一个教师α的传输成本都大于从训练好的学生到下一个教师α的传输成本;
  2. 达到给定的最大步数限制k。直观地说,得到的教师序列弥合了学生模型和教师之间的差距,并提供了一个越来越困难的课程。第4.1节和附录A演示了BGS的有效性。

作者的教师顺序设计方法高效且可扩展。实际上,主要的计算开销是一组微小线性映射(,用于基于FPN的检测器)。在作者的设置中,这个过程需要每个学生模型约3个GPU小时,仅占蒸馏所需数百个GPU小时的一小部分。如果添加更多的教师候选人,作者可以首先为每个教师生成特征图。然后,作者使用只占用10%-20% GPU小时的逐对线性映射来优化,以确保与教师数量相对线性的时间消耗增加。

由于MTPD是一种meta-level策略,它可以与以前的蒸馏机制的设计轻松集成,而无需太多努力。从学生检测器和一组候选教师开始,作者可以首先选择一组教师,并设计它们的蒸馏顺序。然后,作者将更高级的蒸馏机制依次应用于每个教师,以训练学生检测器,而不是简单的特征匹配损失。

3、实验

3.1. 寻找近乎最优的教师顺序

正如作者在第3.2节中讨论的,找到MTPD的最优教师顺序需要阶乘时间复杂度。为了获得接近最优的教师顺序,作者提出了启发式算法Backward Greedy Selection(BGS,伪代码如算法1所示)。

在本节中,作者验证了BGS是否接近最优。为了进行全面比较,作者将使用来自Teachers I-IV教师池的所有顺序来蒸馏Student I,作者使用了一个较少的训练预算:对于每个教师,作者只在MS COCO上训练学生3个Epoch。作者使用线性学习率调度,该调度在有限预算情况下证明是有效的。

作者首先测量了学生和教师模型之间的适应成本。成本图的可视化如图3所示。按照BGS,作者可以构建一系列教师。作者通过蒸馏学生的性能来比较BGS提供的教师顺序与所有其他顺序。

如表2所示,BGS建议的教师顺序在这种情况下一直接近最优。在接下来的章节中,作者使用BGS提供的顺序,而不是遍历所有可能的顺序。

有人可能会认为,如图3所示的BGS的贪心路径选择劣于全局优化算法。然而,作者发现BGS一直优于其他启发式方法,包括全局优化算法。事实上,后来的教师对学生的性能影响更为深远,因此作者需要从序列尾部贪心地选择教师。

3.2. 使用同质教师进行蒸馏

作者首先使用ResNet-50 Backbone (学生I和学生II)对RetinaNet和Mask R-CNN进行蒸馏。在这里,作者考虑通道数和特征图空间分辨率在学生和教师之间保持一致的同质教师。对于RetinaNet学生,作者仍然考虑Teachers I-IV的教师池,与第4.1节相同。

对于Mask R-CNN学生,作者不再使用Teacher I(学生本身)或Teacher II(单阶段教师的性能与学生相比没有明显提高)。为了弥补这一点,作者包括了Teacher V,可以视为DetectoRS Backbone / Neck 和Mask R-CNN Head 的混合模型。因此,Mask R-CNN的教师池包括Teachers III-V。

为了控制总的训练时间,作者将教师数量限制为2。作者从现成的('OTS')学生开始,然后使用2个教师进行顺序蒸馏,每个教师使用1×Schedule。总共,学生蒸馏了24个Epoch,训练时间相当于2×Schedule。

除了OTS学生之外,作者还与其他3个Baseline进行了比较:

  1. 使用较长的3×Schedule训练的学生,在目标检测库中通常支持,并且比1×、2×Schedule更强大;
  2. 使用2×Schedule直接蒸馏的学生最终的目标教师;
  3. 使用教师特征图的集合蒸馏的学生。检测器详细信息列在表1中。

如第4.1节所示,作者使用BGS来确定每个学生使用的教师顺序。对于RetinaNet学生,BGS建议教师顺序为III→IV。对于Mask R-CNN学生,BGS建议教师顺序为V→IV。

表3显示了在COCO上的蒸馏结果。Mask R-CNN蒸馏的其他结果、分析和消融研究请参见附录B。

总体性能:

作者蒸馏的学生模型(第4行和第9行)显著优于OTS学生模型(第1行和第5行)。RetinaNet的框AP从36.5%提高到39.9%(+3.4%)。Mask R-CNN的框AP从38.2%提高到41.4%(+3.2%),Mask R-CNN的面具AP从34.7%提高到37.3%(+2.6%)。经过渐进蒸馏,作者得到的Mask R-CNN检测器性能与HTC教师相媲美,但运行时间要少得多(51毫秒对181毫秒)。

与Baseline的比较:

首先,性能提升不仅仅来自更长的Schedule。作者蒸馏的学生模型(第4行和第9行)一直优于使用3×Schedule(第2行和第6行)训练的原始学生模型。其次,使用教师课程进行渐进蒸馏(第4行和第9行)比直接从强大的教师蒸馏(第3行和第7行)效果更好,即使总的训练时间相同。

此外,作者发现使用教师的顺序(第9行)而不是它们的集合(第8行)更有效。这表明整合多个教师的不同类型知识并不是一件简单的事情,而作者的渐进方法比同时从多个教师那里蒸馏更好。值得注意的是,作者的大型目标检测性能获得了最大的增益(两个模型的APL约提高了6%)。作者强调APL,因为在以效率为中心的实际应用(例如自动驾驶、机器人导航)中,检测附近较大的目标比其他目标更为关键。从现实角度来看,更好的APL显示了作者方法的更好适用性。

3.3. 使用异构教师进行蒸馏

为了验证MTPD的通用性,作者现在考虑一个更具挑战性的异构情景,即学生和教师具有不同的 Backbone 或输入分辨率。具体来说,学生III是一个ResNet-18 Mask R-CNN,经过ResNet-50教师的蒸馏;学生IV是一个带有降低输入分辨率的模型,经过使用较大输入分辨率进行训练的教师的蒸馏。结果总结在表4中,附录C中包括了更多的结果。

异构 Backbone :

学生III具有ResNet-18 Backbone ,运行时间约为其ResNet-50版本(教师I)的一半。作者发现,对于学生III来说,适当的蒸馏方案是使用教师I→V→IV的顺序(而不是集合),这显着改善了学生III的性能,将其从“OTS”模型提高到37.0%(+3.7%)的框AP,尤其是对于大型目标,APL从43.6%提高到50.0%(+6.4%)。

异构输入分辨率:

尽管可以将具有不同分辨率的输入馈送到大多数目标检测器中而不更改架构,但当训练和评估之间存在分辨率不匹配时,性能通常会下降。如果最终作者希望将检测器应用于低分辨率输入以实现快速推断,最好在训练过程中使用低分辨率输入。

另一方面,作者推测使用高分辨率输入的教师可能提供了可以帮助学生的更精细的细节。通过MTPD,作者研究了一个低分辨率学生通过一系列使用高分辨率输入的教师进行蒸馏的改进。作者将标准输入分辨率1333×800表示为1×,将降低分辨率333×200表示为0.25×。作者通过一系列Teacher I变种(0.5×→0.75×→1×)对学生IV(0.25×分辨率)进行蒸馏。

从表4中,作者可以看到MTPD带来的显著改进:框AP从25.8%提高到31.5%(+5.7%),面具AP从23.0%提高到28.2%(+5.2%)。

3.4. 对最新的蒸馏机制的通用性

作者的渐进蒸馏Meta策略,即使用一系列教师逐渐蒸馏学生,与选择每个教师的蒸馏机制无关。作者已经展示了MTPD可以提升简单的特征匹配蒸馏,本节中,作者将MTPD与最新的目标检测蒸馏机制结合起来,以进一步提高学生的准确性。

蒸馏协议:

作者使用检测器蒸馏中的3种最新方法对MTPD进行评估:CWD,FGD和MGD。

为了公平比较,作者使用与它们相同的教师-学生对:RetinaNet/ResNet-50和RetinaNet/ResNeXt-101是单阶段学生和最终教师。RepPoints/ResNet-50和RepPoints/ResNeXt-101是两阶段、Anchor-Free的学生和最终教师。Mask R-CNN/ResNet-50和Cascade Mask R-CNN/ResNeXt-101-DCN是两阶段、Anchor-Base的学生和最终教师。

在它们之间,作者插入一个中等容量的教师,逐渐蒸馏学生:第1对使用RetinaNet/ResNet-101,第2对使用RepPoints/ResNet-101,第3对使用Cascade Mask R-CNN/ResNet50-DCN。

同样为了公平起见,作者保持总的训练轮数相同。作者为每个教师设置1×Schedule,以便总的训练时间等同于2×,与以前的工作相同。

图4显示,MTPD一直提高学生的最终准确性。例如,FGD蒸馏的RetinaNet/ResNet-50的性能从40.7%提高到41.5%AP(+0.8%),而这个增益大于从FGD到MGD的机制提升(+0.3%)。作者几乎不费吹灰之力就为最新的检测蒸馏带来了性能提升。

接下来,作者研究如何进一步提高学生的性能。由于更好的计算效率,卷积型学生(而不是基于Transformer的学生)更受欢迎。同时,Swin Transformer可以充当比以前的卷积型教师更强大的教师。

然而,与卷积型教师相比,直接从这样的教师蒸馏不能提高学生的性能,即使作者使用最先进的MGD方法。例如,基于RetinaNet/Swin-Small(47.1%AP)的性能远远优于RetinaNet/ResNeXt-101(41.6%AP),但从这两者直接蒸馏得到的学生性能相同(41.0%AP)。

为了弥合ResNet-50学生和Swin-Small教师之间的架构差异和容量差距,作者可以利用一个中间的Swin-Tiny教师。

如表5所示,MTPD带来了最好的学生:基于ResNet-50的RetinaNet的性能提高到了42.0%AP,Mask R-CNN的性能提高到了42.5%AP。作者还成功地从卷积型教师那里蒸馏出了基于Transformer的学生。

3.5. 性能提升的解析:泛化还是优化?

作者已经展示了作者的蒸馏学生在验证数据上的准确性明显优于现成的学生。正如图5a所示,蒸馏学生的验证准确性在蒸馏过程中逐渐提高,并在与没有教师的学生相比实现了更高的值。然后自然而然地出现了一个问题——为什么蒸馏有帮助呢?有两种可能的假设:

  1. 改进的优化:蒸馏有助于优化过程,导致更好的局部最小值;
  2. 改进的泛化:蒸馏有助于学生泛化到未见数据。

改进的优化通常通过更好的模型、更低的训练损失和更高的验证准确性来体现,这正是Mask R-CNN、HTC和DetectoRS的情况。因此,人们可能认为蒸馏是以同样的方式工作的。然而,作者的调查表明相反的情况——MTPD增加了验证准确性和训练损失,从而有效地减小了泛化差距。

在图5中,作者比较了原始的RetinaNet模型和经过蒸馏的学生,它们具有相同的架构、相同的延迟,都是在相同的数据上进行训练的,但监督方式不同(只有GT标签与额外的知识蒸馏)。为了消除学习率变化的影响,作者使用3×Schedule训练原始学生,并在与蒸馏学生相同的时间重新启动学习率。有趣的是,尽管蒸馏可以提高学生的验证性能,但经过蒸馏的学生的训练检测损失高于原始学生。这表明蒸馏不帮助优化过程找到具有更低训练损失的局部最小值,而是增强了学生模型的泛化能力。

为了进一步支持这一观察,作者还可视化了局部损失Landscape。与原始学生相比,蒸馏学生具有更平坦的损失Landscape(图5d与图5c相比)。如机器学习文献中广泛认为的,平坦的极小值可以带来更好的泛化。图5中所示的观察结果是针对RetinaNet进行的,但作者在其他学生身上也有类似的观察结果。

总之,知识蒸馏,即强制学生模仿教师的特征,可以被视为一种隐式正则化,有助于学生抵抗过拟合并实现更好的泛化。

4、参考

[1]. Learning Lightweight Object Detectors via Multi-Teacher Progressive Distillation.

相关实践学习
阿里云表格存储使用教程
表格存储(Table Store)是构建在阿里云飞天分布式系统之上的分布式NoSQL数据存储服务,根据99.99%的高可用以及11个9的数据可靠性的标准设计。表格存储通过数据分片和负载均衡技术,实现数据规模与访问并发上的无缝扩展,提供海量结构化数据的存储和实时访问。 产品详情:https://www.aliyun.com/product/ots
相关文章
|
1月前
|
机器学习/深度学习
大模型开发:解释卷积神经网络(CNN)是如何在图像识别任务中工作的。
**CNN图像识别摘要:** CNN通过卷积层提取图像局部特征,池化层减小尺寸并保持关键信息,全连接层整合特征,最后用Softmax等分类器进行识别。自动学习与空间处理能力使其在图像识别中表现出色。
24 2
|
2月前
|
机器学习/深度学习 编解码
LeViT-UNet:transformer 编码器和CNN解码器的有效整合
LeViT-UNet:transformer 编码器和CNN解码器的有效整合
44 0
|
1月前
|
机器学习/深度学习 PyTorch TensorFlow
python实现深度学习模型(如:卷积神经网络)。
【2月更文挑战第14天】【2月更文挑战第38篇】实现深度学习模型(如:卷积神经网络)。
|
2月前
|
机器学习/深度学习 并行计算 算法
模型压缩部署神技 | CNN与Transformer通用,让ConvNeXt精度几乎无损,速度提升40%
模型压缩部署神技 | CNN与Transformer通用,让ConvNeXt精度几乎无损,速度提升40%
60 0
|
2月前
|
机器学习/深度学习 编解码 数据可视化
RecursiveDet | 超越Sparse RCNN,完全端到端目标检测的新曙光
RecursiveDet | 超越Sparse RCNN,完全端到端目标检测的新曙光
57 0
|
2月前
|
机器学习/深度学习 编解码 测试技术
超强Trick | 如何设计一个比Transformer更强的CNN Backbone
超强Trick | 如何设计一个比Transformer更强的CNN Backbone
42 0
|
3月前
|
机器学习/深度学习 网络架构 计算机视觉
CNN经典网络模型之GoogleNet论文解读
GoogleNet,也被称为Inception-v1,是由Google团队在2014年提出的一种深度卷积神经网络架构,专门用于图像分类和特征提取任务。它在ILSVRC(ImageNet Large Scale Visual Recognition Challenge)比赛中取得了优异的成绩,引入了"Inception"模块,这是一种多尺度卷积核并行结构,可以增强网络对不同尺度特征的感知能力。
|
2月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
1月前
|
机器学习/深度学习 算法 数据库
基于CNN卷积网络的MNIST手写数字识别matlab仿真,CNN编程实现不使用matlab工具箱
基于CNN卷积网络的MNIST手写数字识别matlab仿真,CNN编程实现不使用matlab工具箱
|
4月前
|
机器学习/深度学习
CNN卷积神经网络手写数字集实现对抗样本与对抗攻击实战(附源码)
CNN卷积神经网络手写数字集实现对抗样本与对抗攻击实战(附源码)
39 0

热门文章

最新文章