DETR也需要学习 | DETR-Distill模型蒸馏让DETR系类模型持续发光发热!!!

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
交互式建模 PAI-DSW,5000CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: DETR也需要学习 | DETR-Distill模型蒸馏让DETR系类模型持续发光发热!!!

5d173548bc683fa17396fd973c758755.png

基于Transformer的检测器(DETR)由于其稀疏的训练范式和去除后处理操作而引起了极大的关注,但是这个庞大的模型在计算上非常耗时,并且很难在实际应用中部署。为了解决这一问题,可以使用知识蒸馏(KD)来通过构建通用的师生学习框架来压缩庞大的模型。


与传统的CNN检测器不同,DETR将目标检测视为一个集合预测问题,导致在蒸馏过程中教师和学生之间的关系不明确。在本文中提出了DETRDistill,这是一种专门用于DETR系列的知识蒸馏方法。作者首先探索了一种稀疏匹配模式,该模式具有渐进的逐阶段实例蒸馏。考虑到不同DETR中采用的不同注意力机制,作者提出了注意力无关的特征提取模块,以克服传统特征模仿的无效性。最后,为了充分利用教师的中间特征,作者引入了教师辅助匹配提取,它使用教师的目标查询和匹配结果,为一个小组提供额外的指导。


大量实验表明,DETRDistill在各种竞争性DETR方法上实现了显著的改进,而不会在推断阶段引入额外的消耗。这也是第一次探索DETR型检测器通用蒸馏方法的系统研究。

1、简介

目标检测旨在从输入图像中定位和分类视觉目标。在早期的工作中,这项任务通常是通过结合卷积神经网络(CNN)来处理输入图像的区域特征来实现的,其中包括一系列归纳偏差,例如Anchor、标签分配和NMS。最近,已经提出了像DETR这样的基于Transformer的目标检测器,其中检测被视为集合预测任务,并且可以端到端地训练。它们极大地简化了目标检测管道,用户无需对手工设计组件进行繁琐的调整,例如Anchor尺寸和比例。

尽管基于Transformer的检测器已经达到了最先进的精度,但它们仍然存在昂贵的计算问题,使得它们难以在实时应用中部署。为了获得快速准确的检测器,经常需要知识蒸馏(KD)。它使开发者能够在一个复杂但强大的教师模型的帮助下训练一个轻而快速的模型。这样便可以得到一个推理速度快、精度高的目标检测器。

知识蒸馏在传统的基于CNN的物体检测器中取得了巨大的成功。然而,为基于Transformer的检测器设计一个好的KD策略是非常重要的。首先,基于CNN的KD算法很难直接应用于基于Transformer的模型,因为它们的架构不同。具体而言,在基于Transformer的检测器中,目标信息主要被编码在目标查询向量中,而这种信息主要存在于基于CNN的检测器中的图像特征图中,从而导致这两个检测器家族中显著不同的特征分布。

e79dea73400a6f8c6eed6d4b881c9a85.png

例如,在图2中比较了两个检测器的主干特征激活:ATSS(基于CNN)和AdaMixer(基于Transformer)。可以观察到ATSS的特征激活主要围绕目标,而背景区域也在AdaMixer中被激活。因此,为基于CNN的检测器提出的加权策略不适用于基于Transformer的检测器。另一方面,基于Transformer的检测器的集合预测设计也使知识提取具有挑战性:不同的查询目标匹配使得难以将实例级信息从教师传递到学生模型。此外,这些DETR变体中不同的交叉注意力设计进一步阻碍了制定统一的KD策略。

为了解决上述挑战,作者提出了DETRDistill,这是一个专门为基于Transformer的检测器设计的知识蒸馏框架。该工作是探索DETR型检测器知识蒸馏的第一项系统研究。DETRistill基于3个组件:

  1. 渐进式实例蒸馏:通过渐进式逐阶段实例蒸馏建立目标查询之间的预测匹配,以逐步向学生传递有用的知识。同时,监督教师和学生之间的内容查询关系。
  2. 注意力无关的特征提取:与传统的特征模仿不同,使用从每个解码器层聚合的内容查询来恢复特征级提取的注意力掩码,使其与多种交叉注意力机制无关。
  3. 教师辅助匹配提取:为了充分利用由教师模型训练的查询嵌入和相应的二分图匹配结果,作者将其视为学生提供更多固定匹配样本的额外训练流。

98e5af21dd160239a1e3cf5168328d59.png

作者将所提出的DETRDistill应用于广泛使用的MS COCO基准。实验结果验证了该方法的有效性和泛化能力。DETRDistill将最先进的性能存档,与之前的目标检测KD方法相比具有巨大优势。如图1所示,DETRDistill在3个竞争性的基于Transformer的检测器上将学生成绩提高了2.2 AP、2.5 AP和2.4 AP,这甚至超过了各自的教师模型。

2、相关工作

2.1、Transformer-based目标检测

随着Transformer在自然语言处理中的出色表现,研究人员也开始探索Transformer结构在视觉任务中的应用。然而,DETR训练过程效率极低,因此许多后续工作都试图加速收敛。一项工作试图重新设计注意力机制。例如,Dai等人提出了可变形DETR,它通过仅与参考点周围的可变采样点特征交互来构建稀疏注意力机制。SMCA在限制交叉关注之前引入了高斯。

AdaMixer设计了一种新的无编码器的自适应3D特征采样策略,然后将通道和空间维度的采样特征与自适应权重融合。另一行工作重新思考了查询的含义。Meng等人认为,DETR依赖交叉注意力中的内容嵌入来定位目标末端是无效的,因此提出了将查询解耦为内容部分和位置部分。Anchor-DETR直接将查询的2D参考点作为其位置嵌入来引导注意力。DAB-DETR将除了位置之外的宽度和高度信息引入到注意力机制中,以对不同比例的目标进行建模。

2.2、目标检测中的知识蒸馏

知识蒸馏是一种常用的模型压缩方法。《Distilling the knowledge in a neural network》首次提出了这一概念,并将其应用于图像分类领域。他们认为,与一个热编码相比,教师输出的软标签包含类别间相似性的“暗知识”,这有助于模型的泛化。注意力转移将注意力集中在特征图上,并通过缩小教师和学生的注意力分布而不是提取输出逻辑来转移知识。

FitNet建议通过隐藏层模仿教师模型的中级提示。《Learning efficient object detection models with knowledge distillation》首次应用知识蒸馏来解决多目标检测问题。《Mimicking very efficient network for object detection》认为背景区域会引入噪声,并提出提取RPN采样的区域。DeFeat分别提取了前景和背景。FGD分别在焦点区域和全局特征关系方面模仿了教师模型。LD将软标签蒸馏扩展到位置回归,使学生符合教师的边界预测分布。MGD使用掩模图像建模(MIM)将模拟任务转换为图像生成任务。

3、方法

在本节中,首先回顾了DETR的基本架构,然后介绍了提出的DETRDistill的具体实现,它由3个组件组成:

  1. 渐进式蒸馏
  2. 注意力无关的特征蒸馏
  3. 教师辅助匹配蒸馏

e998883a9b31907074de683f28f17d19.png

图3说明了DETRDistill的总体架构。

3.1、回顾DETR

DETR是一个端到端目标检测器,包括CNN主干、可学习查询嵌入、Transformer编码器和解码器。给定图像1675317275432.png,CNN主干提取其空间特征,然后Transformer编码器(某些变体不需要编码器)增强这些特征。具有几个更新的功能1675317290400.png,查询嵌入1675317314562.png被馈送到Transformer解码器以产生检测结果。细化目标查询的计算方法如下:

9c5e38bcb1d6b1f64428ff1295452a5d.png

其中d是特征维度,N是固定的查询数量。1675317398835.png是查询的表示特征。X表示采样特征集,1675317421021.png元素。索引注意力头部,1675317409029.png是可学习的权重,1675317434077.png表示每个查询和key之间的注意力权重。

3.2、DETRDistill

1、Progressive Instance Distillation

最常见的知识蒸馏策略之一是将教师的预测软标签直接传递给学生模型进行学习。然而,预测结果的稀疏性和查询预测的不稳定性使得DETR难以将教师的结果与学生的预测有序地对应起来。


为了实现这一目标,作者利用匈牙利算法来解决DETR稀疏预测的匹配问题。形式上,设1675317452708.png表示教师和学生的预测结果,符合1675317480119.png,其中1675317467774.png分别是教师和学生解码器查询的固定数量。


预测结果的数量由查询的数量确定。1675317558665.png组成,分别表示类别和位置投影。类似地,1675317575729.png组成。由于知识蒸馏的性质,M通常大于或等于N。然后,可以以最低成本搜索教师和学生的输出之间的排列:

00178e88c52ffd93b98471f74294033c.png

1675317598745.png是pair-wise匹配成本,其定义为:

3b2840b97b638ac1b94d04dbab14cb3f.png

其中1675317616105.png是KL损失,1675317635844.png是L1损失和GIoU损失的组合。

然而,上述将教师的知识完全传授给学生的策略可能是次优的。受学习曲线和知识回顾机制的激励,作者认为一个人应该在不同的年龄学习不同层次的知识,当前阶段正确的学习方向是下一阶段成功学习的保证。知识蒸馏的过程类似于上述情况,在此基础上,作者提出了渐进式实例蒸馏。


作者希望学生模型的每个阶段都获得不同的知识水平,为下一阶段的顺利学习奠定坚实的基础。DETR的解码器部分通常包含K个阶段(K>1),当前阶段的预测是前一阶段的细化。因此,可以将每个阶段视为一个学习阶段。没有将教师的最终预测作为一个蒸馏目标,而是仔细地调整教师和学生在每个阶段的输出,让模型逐步学习不同层次的知识,这大大降低了学习难度,提高了更好的表现。


形式上,让1675317654372.png分别表示教师第k阶段的类别和位置,1675317668942.png表示学生阶段。根据等式(3),可以使用教师和学生的第k阶段的输出来获得相应的排列1675317679612.png。因此,第k阶段的蒸馏损失可写成:

167983b51295f005465d850a97e7352f.png

其中1675317696396.png是二元交叉熵损失,1675317712084.png是GIoU损失,1675317722741.png是L1损失,α、β和γ是重加权因子。此外,考虑到一对一实例蒸馏是不完整的,实例关系是不可或缺的,作者采用欧氏距离来表示教师中目标查询q之间的关系信息,并将其传递给具有L1损失的学生,类似于GID:

494c708609c081ad4832b2b6874b6215.png

2、Attention-agnostic Feature Distillation

基于CNN的检测器通过将多个卷积层连接到主干的特征来直接输出预测结果。由于卷积网络保持固定的方形感受野,检测头获取的特征来自目标区域的均匀插值采样,可以基于GT轻松地划分。然而,DETR中解码器的主要操作是自注意力和交叉注意力。自注意力是查询之间的交互模式,通常被理解为防止重复预测的结构。交叉注意力是从特征中提取和聚集目标信息的主要方式。


DETR变体之间的差异主要在于特征采样策略和注意力的生成方法。原始DETR使用冗余多头注意力机制:通过基于查询与特征图的每个位置之间的余弦相似度计算注意力,经过加权融合,可以获得具有更丰富语义信息的目标特征。可变形DETR提出了一种类似于可变形卷积网络的可学习采样方法。它可以在不受固定区域限制的情况下自适应地对参考点附近的稀疏特征进行采样,并通过神经网络生成注意力权重。AdaMixer的自适应混合方法允许在整个图像空间中跨特征层进行采样。值得注意的是,这种灵活的采样方法是DETR的最大优点之一。


考虑到不同的注意力机制,作者没有设计复杂的策略来划分特征,而是适应不同DETR的注意力,如图3(b)所示。模拟教师空间特征的典型方式可以通过以下方式计算:

71b6a65090bae519e4e83ce75d5931df.png

其中1675317780079.png分别表示教师和学生的特征,1675317792824.png是reshape维度的适应层。H、 W表示特征的高度和宽度,d表示通道。1675317821231.png表示所选和分离区域的soft mask。例如,FitNet将mask视为填充了1的矩阵。Wang等人提出了通过GT计算的mask。Sun等人利用高斯mask覆盖GT。作者试图通过计算查询和特征之间的相似度的交互式方式来恢复注意力,以生成soft mask1675317792824.png 。将教师的全部提问作为1675317833024.png,形状为(M,H×W)的纯注意力mask可以通过以下方式获得:

65a04758ed18f9179ac727c8bb4110bc.png

其中M表示来自教师解码器的解码器查询数,表示Sigmoid函数,f是投影层,教师的特征图1675317852176.png被展平为形状1675317916634.png。然后,通过1675317936915.png将与每个查询对应的所有mask合并为单个mask。

a5b7a73d052b8d20ea10849c2e5ff1f9.png

然而,从经验上发现,这种标准蒸馏方法效果不佳,因为并非来自教师的所有目标查询都应被同等地视为有价值的线索,并且来自查询的大多数预测都是否定的,如图4所示。为此,通过考虑分类和定位来探索教师目标查询的质量。具体而言,质量分数由以下公式计算:

409aed8bd7620d5d94c3912397cb8152.png

其中1675317953268.png分别表示来自查询的分类分数和预测框。质量分数作为一个指标,用来指导哪个查询应该得到更多的重视。因此,最终的mask引导蒸馏可以写成:

98a759cefcadd279c23e53c90f67cb79.png

3、Teacher-assisted Assignment Distillation

当初始训练学生模型时,噪声填充的查询嵌入会导致不稳定的二分图匹配。如DN-DETR中所述,查询通常与不同时期的不同目标相匹配,这使得优化变得不明确和不稳定。在知识蒸馏的设置中,通常拥有经过训练的教师模型的所有参数,包括查询嵌入,其中包含关于目标及其分配结果的足够信息。利用这些信息来提高优化方向的稳定性是直观的。


基于这一动机,作者提出了教师辅助标签分配。设1675317973220.png分别表示教师的查询嵌入和分配排列。在维度由投影层对齐后,教师的查询被输入到学生模型中,该投影层由DETR的原始损失和分配的1675317988800.png进行监督,记录为1675318021085.png。值得注意的是,教师辅助的匹配蒸馏和学生自己的原始训练过程共享网络参数,而无需在它们之间进行任何信息交互。这使得无论学生查询如何变化,教师总是有稳定的查询和匹配来指导解码器的优化。

4、Overall loss

综上所述,训练学生DETR,总损失如下:

e564bb57a0f4538607e533974e2e3673.png

其中1675318033998.png是DETR的原始损失。蒸馏方法遵循常见的DETR范式,可以很容易地应用于各种检测器。

4、实验

4.1、消融实验

1、主要结果

73f9caa4bf5d3fa0e517973852c4e81c.png

2、分析渐进式蒸馏

04bcf6b608cae9b2486fb9b15c1868a3.png

3、分类和回归分支样本消融

3cdfe82c957ee8433c2e4eee782f0f36.png

4、特色分区规划策略分析

9577c1ab71fbcfde667dfd2030685736.png0f9a6410cae1a9226f438e2d9a34744c.png

4.2、COCO

bd5d18e777e7259127d82474896adbe1.png

4.3、Distilling to Lightweight Backbones

86a221ba65cb1a795b2868a061c27d09.png

4.4、Self-Distillation

f1e370c40a0454b8c0d7b8f808813081.png

5、参考

[1].DETRDistill: A Universal Knowledge Distillation Framework for DETR-families.

6、推荐阅读

目标检测系列 | 无NMS的端到端目标检测模型,超越OneNet,FCOS等SOTA!

目标检测落地技能 | 拥挤目标检测你是如何解决的呢?改进Copy-Paste解决拥挤问题!

多目标跟踪新SOTA | TransTrack改进版本来啦,模型减小58.73%,复杂性降低78.72%

相关文章
|
5月前
|
算法 计算机视觉
YOLOv8改进 | 损失函数篇 | 最新ShapeIoU、InnerShapeIoU损失助力细节涨点
YOLOv8改进 | 损失函数篇 | 最新ShapeIoU、InnerShapeIoU损失助力细节涨点
312 2
|
5月前
|
机器学习/深度学习 算法 固态存储
最强DETR+YOLO | 三阶段的端到端目标检测器的DEYOv2正式来啦,性能炸裂!!!
最强DETR+YOLO | 三阶段的端到端目标检测器的DEYOv2正式来啦,性能炸裂!!!
200 0
|
5月前
|
机器学习/深度学习 算法 计算机视觉
YOLOv5改进 | 损失函数篇 | 最新ShapeIoU、InnerShapeIoU损失助力细节涨点
YOLOv5改进 | 损失函数篇 | 最新ShapeIoU、InnerShapeIoU损失助力细节涨点
271 1
|
3月前
|
机器学习/深度学习 计算机视觉
YOLOv10实战:红外小目标实战 | 多头检测器提升小目标检测精度
本文改进: 在进行目标检测时,小目标会出现漏检或检测效果不佳等问题。YOLOv10有3个检测头,能够多尺度对目标进行检测,但对微小目标检测可能存在检测能力不佳的现象,因此添加一个微小物体的检测头,能够大量涨点,map提升明显; 多头检测器提升小目标检测精度,1)mAP50从0.666提升至0.677
631 3
|
5月前
|
机器学习/深度学习
YOLOv8改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
YOLOv8改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
338 1
YOLOv8改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
|
5月前
|
机器学习/深度学习 边缘计算 自动驾驶
【初探GSConv】轻量化卷积层直接带来的小目标检测增益!摘录于自动驾驶汽车检测器的架构
【初探GSConv】轻量化卷积层直接带来的小目标检测增益!摘录于自动驾驶汽车检测器的架构
337 0
【初探GSConv】轻量化卷积层直接带来的小目标检测增益!摘录于自动驾驶汽车检测器的架构
|
5月前
|
机器学习/深度学习 编解码 数据可视化
即插即用 | 高效多尺度注意力模型成为YOLOv5改进的小帮手
即插即用 | 高效多尺度注意力模型成为YOLOv5改进的小帮手
316 1
|
5月前
|
机器学习/深度学习 固态存储 算法
目标检测的福音 | 如果特征融合还用FPN/PAFPN?YOLOX+GFPN融合直接起飞,再涨2个点
目标检测的福音 | 如果特征融合还用FPN/PAFPN?YOLOX+GFPN融合直接起飞,再涨2个点
223 0
|
5月前
|
机器学习/深度学习 编解码
YOLOv5改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
YOLOv5改进 | 2023主干篇 | RepViT从视觉变换器(ViT)的视角重新审视CNN
249 0
|
计算机视觉
DETR也需要学习 | DETR-Distill模型蒸馏让DETR系类模型持续发光发热!!!(二)
DETR也需要学习 | DETR-Distill模型蒸馏让DETR系类模型持续发光发热!!!(二)
234 0
下一篇
无影云桌面