详细解读 | 如何让你的DETR目标检测模型快速收敛(一)

简介: 详细解读 | 如何让你的DETR目标检测模型快速收敛(一)

1简介


最近发展起来的DETR方法将transformer编解码器体系结构应用于目标检测并取得了很好的性能。在本文中,作者解决了训练收敛速度慢这一关键问题,并提出了一种conditional cross-attention mechanism用于快速训练DETR。作者动机是cross-attention在DETR中高度依赖content embeddings定位的4端和预测框,这增加了对高质量content embedding的需求进而增加了训练的难度。

本文方法被命名为条件DETR,从decoder embedding中学习了一个conditional spatial query用于decoder multi-head cross-attention。其好处是,通过conditional spatial query每个cross-attention head能够关注包含不同区域的band(例如,一个目标端点或目标box内的一个区域)。这缩小了目标分类和box回归的不同区域定位的空间范围,从而减轻了对content embedding的依赖,减轻了训练。

实验结果表明,对于Backbone R50和R101,条件DETR收敛速度快6.7倍;对于backboone DC5-R50和DC5-R101,条件DETR收敛速度快10倍。


2背景


DETR方法将transformer应用于目标检测取得了良好的性能。它有效地消除了许多手工制作组件的需要,包括NMS和Anchor生成。

DETR方法在训练上收敛缓慢,需要500个epoch才能取得良好的效果。Deformable DETR通过使用高分辨率和多尺度编码器将global dense attention(self-attention和cross-attention)替换为deformable attention来解决这个问题。相反,本文仍然使用global dense attention并提出了一个改进的 decoder cross-attention mechanism以加速训练收敛的过程。

本文方法的动机是高度依赖content embeddings和spatial embeddings in cross-attention。实验结果表明,如果从第2解码器层去除key和query中的位置嵌入,只使用key和query中的content embeddings,检测AP略有下降(1%)。

image.png

图1

图1(第2行)显示了50个epoch训练的DETR cross-attention的spatial attention weight maps。可以看到,4个映射中有2个没有正确地突出对应端点的波段,因此在缩小内容查询的空间范围以精确定位端点方面很弱。其原因是:

  1. spatial query:即目标query,只给出general attention weight map,而没有利用具体的图像信息;
  2. 由于训练时间短content embeddings不够强,不能很好地匹配spatial key,因为它们也用于匹配content key。这增加了对高质量content embeddings的依赖性,从而增加了训练难度。

本文提出了一种有条件的DETR方法,该方法从之前对应的解码器输出嵌入中学习每个query的条件spatial embedding,形成decoder multi-head cross-attention的条件spatial query。通过将用于回归目标框的信息映射到嵌入空间来预测条件spatial query。


3条件DETR


3.1 方法概览

该方法采用端到端目标检测器(detection transformer, DETR),无需生成NMS或Anchor即可一次性预测所有目标。该体系结构由CNN Backbone、transformer encoder、transformer decoder、目标分类器和边界框位置预测器组成。transformer encoder的目的是改进CNN Backbone的content embeddings输出。它是由多个编码器层组成的堆栈,其中每一层主要由self-attention层和feed-forward层组成。

image.png

图3

transformer decoder是一堆decoder layer。每个decoder layer如图3所示,由3个主要层组成:

  1. self-attention layer:用于去除重复预测,执行前一解码器层输出的嵌入之间的交互,用于类和box的预测;
  2. cross-attention layer:该层聚合编码器输出的embedding以细化解码器embedding改进类和box预测;
  3. feed-forward layer

Box回归

从每个decoder embedding中预测一个候选框,如下所示:

image.png

这里,是decoder embedding。是一个4维矢量,由框的中心、框的宽度和框的高度组成。Sigmoid()用于将预测b归一化到范围[0,1]。FFN()的目的是预测非规范化框。在原始DETR中为(0,0),s为参考点的非归一化二维坐标。在本方法中,作者考虑2个选择:将参考点s作为每个候选框预测的参数学习,或者从相应的目标query中生成它。

类别预测

每个候选框的分类score也通过FNN预测:

Main work

cross-attention mechanism的目的是定位不同的区域(用于box检测的4个端点和box内用于目标分类的区域)并聚合相应的嵌入。本文提出了一种条件cross-attention mechanism,通过引入conditional spatial query来提高定位能力和加速训练的收敛过程。

3.2 DETR Decoder Cross-Attention

DETR解码器cross-attention mechanism有3个输入:query、key和value。每个key都是通过添加一个content key (编码器的content embedding输出)和一个spatial key (对应的标准化2D坐标的positional embedding)形成的。该value是由编码器输出的content embedding(与content key相同)形成的。

在原始的DETR方法中,每个query是通过添加content query (decoder self-attention embedding)和spatial query (即object query )形成的。在实现中,有N=300个object queries,相应地有N个query,每个query在一个解码器层输出一个候选检测结果。

attention weight是基于query与key的点积:

image.png

3.3 Conditional Cross-Attention

提出的Conditional Cross-Attention将解码器self-attention输出的content query 和spatial query 串联起来形成query。因此,同理,key由content key 和spatial key 拼接而成。

cross-attention weight由content attention weight和spatial attention weight两部分组成。这两个权重来自两个点积,content和spatial点积:

image.png

与原来的DETR Cross-Attention不同,本文所提的机制分离了content query和spatial query的角色,使spatial query和content query分别关注spatial和content的attention weight。

另外一个重要的任务是从前一个解码器层的embedding 计算spatial query 。首先识别出不同区域的空间信息是由解码器 embedding 和参考点这两个因素共同决定的。然后展示了如何将它们映射到embedding space形成query ,使spatial位于key的2D坐标映射到的同一空间。

解码器embedding包含不同区域相对于参考点的位移。式1中的box预测过程包括2个步骤:

  1. 对非归一化空间中的参考点进行预测;
  2. 将预测框归一化到范围[0,1];

步骤(1)表示decoder embedding f包含了构成方框的4个端点相对于非归一化空间中的参考点s的位移。这意味着,无论是embedding f还是参考点s,都需要确定不同区域、4个极值以及预测分类评分的区域的空间信息。

Conditional spatial query prediction

通过embedding f和参考点s预测条件空间查询,

image.png

以便与key的标准化2D坐标映射到的位置空间对齐。这个过程如图3的灰色框区所示。

这里将参考点归一化,然后将其映射到256维正弦位置嵌入,方法与key的位置嵌入相同:

image.png

然后通过可学习线性投影+ReLU+可学习线性投影组成的FFN将解码器embedding f中包含的位移信息映射到同一空间中的线性投影:

image.png

conditional spatial query通过转换embedding空间中的参考点来计算:

image.png

作者选择简单和计算效率高的对角矩阵。256个对角线元素被表示为一个向量。conditional spatial query通过逐元素乘法计算:

image.png

Multi-head cross-attention

与DETR一样作者采用标准的Multi-head cross-attention机制。目标检测通常需要隐式或显式定位目标的4个端点以实现精确的box回归,并定位目标区域以实现精确的目标分类。multi-head mechanism有利于解决定位任务的纠缠问题。

作者通过将query、key和value M=8次投影到低维的线性投影来执行multi-head parallel attentions。spatial和content query(key)分别以不同的线性投影投影到每个head。value投影与原始的DETR相同,仅用于content。

3.4 可视化分析

图4可视化了每个attention weight maps:

  • spatial attention weight maps
  • content attention weight maps
  • combined attention weight maps

对spatial dot-products 、 content dot-products 和combined dot-products 进行softmax normalized。图中显示了8个map中的5个其他3个是重复的,对应于底部和顶部的端点,以及目标框内的一个小区域。

image.png

图4

可以看到,每个head的spatial attention weight maps能够定位一个不同的区域(包含一个极点的区域或物体box内的区域)。有趣的是,每个spatial attention weight maps对应的一个极点突出了一个空间带,该空间带与目标框的相应边缘重叠。目标框内区域的另一个spatial attention weight maps仅仅突出显示了一个小区域,该区域的表示可能已经编码了足够的目标分类信息。

content attention weight maps还突出了分散的区域。空间和内容映射的组合过滤掉了其他高亮部分,并保留了极端高亮部分以实现精确的box回归。

Comparison to DETR

图1显示了条件式DETR(第1行)和经过50个epoch训练的原始DETR(第2行)的spatial attention weight maps。本文方法的映射是通过spatial key和query之间的dot的softmax normalized来计算的:

可以看出spatial attention weight maps准确定位了不同的区域。相比之下,原始的DETR中包含50个epoch的map不能准确定位2个极点,而500个训练epoch(第3行)使得content query更强,从而实现了精确定位。这意味着学习content query 作为2个角色(同时匹配content key和spatial key)是非常困难的,因此需要更多的训练epoch。

分析

图4所示的spatial attention weight maps暗示用于形成spatial query的conditional spatial query至少有2种效果:

  1. 将突出显示的位置转换为4个端点和目标框内的位置:有趣的是,突出显示的位置在目标框内的空间分布相似;
  2. 缩放顶端亮点的空间扩展:大目标的空间扩展大,小目标的空间扩展小。

这2种效果是在spatial embedding space中通过T/ps变换实现的(通过cross-attention中包含的独立于图像的线性投影进一步分离,并分布到每个head)。这说明变换T不仅包含前面讨论的位移,还包含目标尺度。

相关文章
|
机器学习/深度学习 存储 算法
神经网络中的量化与蒸馏
本文将深入研究深度学习中精简模型的技术:量化和蒸馏
102 0
|
4月前
|
计算机视觉 异构计算
【YOLOv8改进-SPPF】 AIFI : 基于注意力的尺度内特征交互,保持高准确度的同时减少计算成本
YOLOv8专栏介绍了该系列目标检测框架的最新改进与实战应用。文章提出RT-DETR,首个实时端到端检测器,解决了速度与精度问题。通过高效混合编码器和不确定性最小化查询选择,RT-DETR在COCO数据集上实现高AP并保持高帧率,优于其他YOLO版本。论文和代码已开源。核心代码展示了AIFI Transformer层,用于位置嵌入。更多详情见[YOLOv8专栏](https://blog.csdn.net/shangyanaf/category_12303415.html)。
|
机器学习/深度学习 编解码 PyTorch
DenseNet、MobileNet、DPN…你都掌握了吗?一文总结图像分类必备经典模型(二)
DenseNet、MobileNet、DPN…你都掌握了吗?一文总结图像分类必备经典模型(二)
203 0
|
机器学习/深度学习 算法
基于贝叶斯优化CNN-LSTM混合神经网络预测(Matlab代码实现)
基于贝叶斯优化CNN-LSTM混合神经网络预测(Matlab代码实现)
215 0
|
机器学习/深度学习 传感器 算法
【CNN回归预测】基于贝叶斯优化卷积神经网络BO-CNN实现数据回归预测附matlab代码
【CNN回归预测】基于贝叶斯优化卷积神经网络BO-CNN实现数据回归预测附matlab代码
|
计算机视觉
【目标检测出】评价指标
【目标检测出】评价指标
144 0
|
计算机视觉
详细解读 | 如何让你的DETR目标检测模型快速收敛(二)
详细解读 | 如何让你的DETR目标检测模型快速收敛(二)
252 0
|
计算机视觉
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(二)
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(二)
137 0
|
机器学习/深度学习 计算机视觉 索引
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(一)
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(一)
538 0
下一篇
无影云桌面