DETR也需要学习 | DETR-Distill模型蒸馏让DETR系类模型持续发光发热!!!(一)

简介: DETR也需要学习 | DETR-Distill模型蒸馏让DETR系类模型持续发光发热!!!(一)

基于Transformer的检测器(DETR)由于其稀疏的训练范式和去除后处理操作而引起了极大的关注,但是这个庞大的模型在计算上非常耗时,并且很难在实际应用中部署。为了解决这一问题,可以使用知识蒸馏(KD)来通过构建通用的师生学习框架来压缩庞大的模型。

与传统的CNN检测器不同,DETR将目标检测视为一个集合预测问题,导致在蒸馏过程中教师和学生之间的关系不明确。在本文中提出了DETRDistill,这是一种专门用于DETR系列的知识蒸馏方法。作者首先探索了一种稀疏匹配模式,该模式具有渐进的逐阶段实例蒸馏。考虑到不同DETR中采用的不同注意力机制,作者提出了注意力无关的特征提取模块,以克服传统特征模仿的无效性。最后,为了充分利用教师的中间特征,作者引入了教师辅助匹配提取,它使用教师的目标查询和匹配结果,为一个小组提供额外的指导。

大量实验表明,DETRDistill在各种竞争性DETR方法上实现了显著的改进,而不会在推断阶段引入额外的消耗。这也是第一次探索DETR型检测器通用蒸馏方法的系统研究。


1、简介


目标检测旨在从输入图像中定位和分类视觉目标。在早期的工作中,这项任务通常是通过结合卷积神经网络(CNN)来处理输入图像的区域特征来实现的,其中包括一系列归纳偏差,例如Anchor、标签分配和NMS。最近,已经提出了像DETR这样的基于Transformer的目标检测器,其中检测被视为集合预测任务,并且可以端到端地训练。它们极大地简化了目标检测管道,用户无需对手工设计组件进行繁琐的调整,例如Anchor尺寸和比例。

尽管基于Transformer的检测器已经达到了最先进的精度,但它们仍然存在昂贵的计算问题,使得它们难以在实时应用中部署。为了获得快速准确的检测器,经常需要知识蒸馏(KD)。它使开发者能够在一个复杂但强大的教师模型的帮助下训练一个轻而快速的模型。这样便可以得到一个推理速度快、精度高的目标检测器。

知识蒸馏在传统的基于CNN的物体检测器中取得了巨大的成功。然而,为基于Transformer的检测器设计一个好的KD策略是非常重要的。首先,基于CNN的KD算法很难直接应用于基于Transformer的模型,因为它们的架构不同。具体而言,在基于Transformer的检测器中,目标信息主要被编码在目标查询向量中,而这种信息主要存在于基于CNN的检测器中的图像特征图中,从而导致这两个检测器家族中显著不同的特征分布。

image.png

例如,在图2中比较了两个检测器的主干特征激活:ATSS(基于CNN)和AdaMixer(基于Transformer)。可以观察到ATSS的特征激活主要围绕目标,而背景区域也在AdaMixer中被激活。因此,为基于CNN的检测器提出的加权策略不适用于基于Transformer的检测器。另一方面,基于Transformer的检测器的集合预测设计也使知识提取具有挑战性:不同的查询目标匹配使得难以将实例级信息从教师传递到学生模型。此外,这些DETR变体中不同的交叉注意力设计进一步阻碍了制定统一的KD策略。

为了解决上述挑战,作者提出了DETRDistill,这是一个专门为基于Transformer的检测器设计的知识蒸馏框架。该工作是探索DETR型检测器知识蒸馏的第一项系统研究。DETRistill基于3个组件:

  1. 渐进式实例蒸馏:通过渐进式逐阶段实例蒸馏建立目标查询之间的预测匹配,以逐步向学生传递有用的知识。同时,监督教师和学生之间的内容查询关系。
  2. 注意力无关的特征提取:与传统的特征模仿不同,使用从每个解码器层聚合的内容查询来恢复特征级提取的注意力掩码,使其与多种交叉注意力机制无关。
  3. 教师辅助匹配提取:为了充分利用由教师模型训练的查询嵌入和相应的二分图匹配结果,作者将其视为学生提供更多固定匹配样本的额外训练流。

image.png

作者将所提出的DETRDistill应用于广泛使用的MS COCO基准。实验结果验证了该方法的有效性和泛化能力。DETRDistill将最先进的性能存档,与之前的目标检测KD方法相比具有巨大优势。如图1所示,DETRDistill在3个竞争性的基于Transformer的检测器上将学生成绩提高了2.2 AP、2.5 AP和2.4 AP,这甚至超过了各自的教师模型。


2、相关工作


2.1、Transformer-based目标检测

随着Transformer在自然语言处理中的出色表现,研究人员也开始探索Transformer结构在视觉任务中的应用。然而,DETR训练过程效率极低,因此许多后续工作都试图加速收敛。一项工作试图重新设计注意力机制。例如,Dai等人提出了可变形DETR,它通过仅与参考点周围的可变采样点特征交互来构建稀疏注意力机制。SMCA在限制交叉关注之前引入了高斯。

AdaMixer设计了一种新的无编码器的自适应3D特征采样策略,然后将通道和空间维度的采样特征与自适应权重融合。另一行工作重新思考了查询的含义。Meng等人认为,DETR依赖交叉注意力中的内容嵌入来定位目标末端是无效的,因此提出了将查询解耦为内容部分和位置部分。Anchor-DETR直接将查询的2D参考点作为其位置嵌入来引导注意力。DAB-DETR将除了位置之外的宽度和高度信息引入到注意力机制中,以对不同比例的目标进行建模。

2.2、目标检测中的知识蒸馏

知识蒸馏是一种常用的模型压缩方法。《Distilling the knowledge in a neural network》首次提出了这一概念,并将其应用于图像分类领域。他们认为,与一个热编码相比,教师输出的软标签包含类别间相似性的“暗知识”,这有助于模型的泛化。注意力转移将注意力集中在特征图上,并通过缩小教师和学生的注意力分布而不是提取输出逻辑来转移知识。

FitNet建议通过隐藏层模仿教师模型的中级提示。《Learning efficient object detection models with knowledge distillation》首次应用知识蒸馏来解决多目标检测问题。《Mimicking very efficient network for object detection》认为背景区域会引入噪声,并提出提取RPN采样的区域。DeFeat分别提取了前景和背景。FGD分别在焦点区域和全局特征关系方面模仿了教师模型。LD将软标签蒸馏扩展到位置回归,使学生符合教师的边界预测分布。MGD使用掩模图像建模(MIM)将模拟任务转换为图像生成任务。


3、方法


在本节中,首先回顾了DETR的基本架构,然后介绍了提出的DETRDistill的具体实现,它由3个组件组成:

  1. 渐进式蒸馏
  2. 注意力无关的特征蒸馏
  3. 教师辅助匹配蒸馏

图3说明了DETRDistill的总体架构。

3.1、回顾DETR

DETR是一个端到端目标检测器,包括CNN主干、可学习查询嵌入、Transformer编码器和解码器。给定图像,CNN主干提取其空间特征,然后Transformer编码器(某些变体不需要编码器)增强这些特征。具有几个更新的功能,查询嵌入被馈送到Transformer解码器以产生检测结果。细化目标查询的计算方法如下:

其中是特征维度,是固定的查询数量。是查询的表示特征。表示采样特征集,是key元素。索引注意力头部,和是可学习的权重,表示每个查询和key之间的注意力权重。

3.2、DETRDistill

1、Progressive Instance Distillation

最常见的知识蒸馏策略之一是将教师的预测软标签直接传递给学生模型进行学习。然而,预测结果的稀疏性和查询预测的不稳定性使得DETR难以将教师的结果与学生的预测有序地对应起来。

为了实现这一目标,作者利用匈牙利算法来解决DETR稀疏预测的匹配问题。形式上,设和表示教师和学生的预测结果,符合和,其中和分别是教师和学生解码器查询的固定数量。

预测结果的数量由查询的数量确定。由和组成,分别表示类别和位置投影。类似地,由和组成。由于知识蒸馏的性质,通常大于或等于。然后,可以以最低成本搜索教师和学生的输出之间的排列:

是pair-wise匹配成本,其定义为:

image.png

其中是KL损失,是L1损失和GIoU损失的组合。

然而,上述将教师的知识完全传授给学生的策略可能是次优的。受学习曲线和知识回顾机制的激励,作者认为一个人应该在不同的年龄学习不同层次的知识,当前阶段正确的学习方向是下一阶段成功学习的保证。知识蒸馏的过程类似于上述情况,在此基础上,作者提出了渐进式实例蒸馏。

作者希望学生模型的每个阶段都获得不同的知识水平,为下一阶段的顺利学习奠定坚实的基础。DETR的解码器部分通常包含K个阶段(K>1),当前阶段的预测是前一阶段的细化。因此,可以将每个阶段视为一个学习阶段。没有将教师的最终预测作为一个蒸馏目标,而是仔细地调整教师和学生在每个阶段的输出,让模型逐步学习不同层次的知识,这大大降低了学习难度,提高了更好的表现。

形式上,让和分别表示教师第k阶段的类别和位置,和表示学生阶段。根据等式(3),可以使用教师和学生的第k阶段的输出来获得相应的排列。因此,第k阶段的蒸馏损失可写成:

image.png

其中是二元交叉熵损失,是GIoU损失,是L1损失,α、β和γ是重加权因子。此外,考虑到一对一实例蒸馏是不完整的,实例关系是不可或缺的,作者采用欧氏距离来表示教师中目标查询q之间的关系信息,并将其传递给具有L1损失的学生,类似于GID:

image.png

2、Attention-agnostic Feature Distillation

基于CNN的检测器通过将多个卷积层连接到主干的特征来直接输出预测结果。由于卷积网络保持固定的方形感受野,检测头获取的特征来自目标区域的均匀插值采样,可以基于GT轻松地划分。然而,DETR中解码器的主要操作是自注意力和交叉注意力。自注意力是查询之间的交互模式,通常被理解为防止重复预测的结构。交叉注意力是从特征中提取和聚集目标信息的主要方式。

DETR变体之间的差异主要在于特征采样策略和注意力的生成方法。原始DETR使用冗余多头注意力机制:通过基于查询与特征图的每个位置之间的余弦相似度计算注意力,经过加权融合,可以获得具有更丰富语义信息的目标特征。可变形DETR提出了一种类似于可变形卷积网络的可学习采样方法。它可以在不受固定区域限制的情况下自适应地对参考点附近的稀疏特征进行采样,并通过神经网络生成注意力权重。AdaMixer的自适应混合方法允许在整个图像空间中跨特征层进行采样。值得注意的是,这种灵活的采样方法是DETR的最大优点之一。

考虑到不同的注意力机制,作者没有设计复杂的策略来划分特征,而是适应不同DETR的注意力,如图3(b)所示。模拟教师空间特征的典型方式可以通过以下方式计算:

image.png

其中和分别表示教师和学生的特征,是reshape维度的适应层。H、 W表示特征的高度和宽度,d表示通道。表示所选和分离区域的soft mask。例如,FitNet将mask视为填充了1的矩阵。Wang等人提出了通过GT计算的mask。Sun等人利用高斯mask覆盖GT。作者试图通过计算查询和特征之间的相似度的交互式方式来恢复注意力,以生成soft mask 。将教师的全部提问作为,形状为(M,H×W)的纯注意力mask可以通过以下方式获得:

其中M表示来自教师解码器的解码器查询数,表示Sigmoid函数,f是投影层,教师的特征图被展平为形状(d,H×W)。然后,通过将与每个查询对应的所有mask合并为单个mask。

image.png

然而,从经验上发现,这种标准蒸馏方法效果不佳,因为并非来自教师的所有目标查询都应被同等地视为有价值的线索,并且来自查询的大多数预测都是否定的,如图4所示。为此,通过考虑分类和定位来探索教师目标查询的质量。具体而言,质量分数由以下公式计算:

image.png

其中和分别表示来自查询的分类分数和预测框。质量分数作为一个指标,用来指导哪个查询应该得到更多的重视。因此,最终的mask引导蒸馏可以写成:

image.png

3、Teacher-assisted Assignment Distillation

当初始训练学生模型时,噪声填充的查询嵌入会导致不稳定的二分图匹配。如DN-DETR中所述,查询通常与不同时期的不同目标相匹配,这使得优化变得不明确和不稳定。在知识蒸馏的设置中,通常拥有经过训练的教师模型的所有参数,包括查询嵌入,其中包含关于目标及其分配结果的足够信息。利用这些信息来提高优化方向的稳定性是直观的。

基于这一动机,作者提出了教师辅助标签分配。设和分别表示教师的查询嵌入和分配排列。在维度由投影层对齐后,教师的查询被输入到学生模型中,该投影层由DETR的原始损失和分配的进行监督,记录为。值得注意的是,教师辅助的匹配蒸馏和学生自己的原始训练过程共享网络参数,而无需在它们之间进行任何信息交互。这使得无论学生查询如何变化,教师总是有稳定的查询和匹配来指导解码器的优化。

4、Overall loss

综上所述,训练学生DETR,总损失如下:

其中是DETR的原始损失。蒸馏方法遵循常见的DETR范式,可以很容易地应用于各种检测器。

相关文章
|
机器学习/深度学习 数据采集 算法
四足动物模型控制中的模型自适应神经网络
翻译:《Mode-Adaptive Neural Networks for Quadruped Motion Control》
100 0
|
计算机视觉
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(二)
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(二)
279 0
|
8月前
|
机器学习/深度学习
YOLOv8改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
YOLOv8改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
433 1
YOLOv8改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
|
8月前
|
机器学习/深度学习 数据挖掘 测试技术
DETR即插即用 | RefineBox进一步细化DETR家族的检测框,无痛涨点
DETR即插即用 | RefineBox进一步细化DETR家族的检测框,无痛涨点
421 1
|
8月前
|
机器学习/深度学习 人工智能 计算机视觉
CVPR 2023 | AdaAD: 通过自适应对抗蒸馏提高轻量级模型的鲁棒性
CVPR 2023 | AdaAD: 通过自适应对抗蒸馏提高轻量级模型的鲁棒性
267 0
|
8月前
|
机器学习/深度学习 编解码
YOLOv5改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
YOLOv5改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
349 0
|
计算机视觉
DETR也需要学习 | DETR-Distill模型蒸馏让DETR系类模型持续发光发热!!!(二)
DETR也需要学习 | DETR-Distill模型蒸馏让DETR系类模型持续发光发热!!!(二)
270 0
|
机器学习/深度学习 传感器 编解码
CenterFormer | CenterNet思想究竟有多少花样?看CenterFormer在3D检测全新SOTA
CenterFormer | CenterNet思想究竟有多少花样?看CenterFormer在3D检测全新SOTA
163 0
|
机器学习/深度学习 计算机视觉 索引
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(一)
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(一)
583 0
|
计算机视觉
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(二)
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(二)
156 0

热门文章

最新文章