即插即用 | 通过自适应聚类Transformer来提升DERT目标检测器的速度(文末附论文下载)(一)

简介: 即插即用 | 通过自适应聚类Transformer来提升DERT目标检测器的速度(文末附论文下载)(一)

1、简介


DERT中使用Transformer进行端到端目标检测,并实现与Faster-RCNN等两阶段目标检测可比的性能。但是,由于高分辨率的图像输入,DETR需要大量的计算资源来进行训练和推理。为了降低高分辨率输入的计算成本,本文提出了一种新型的变体,称为自适应聚类Transformer(ACT)

ACT使用局部敏感哈希(LSH)自适应地对查询特征进行聚类,并使用原型键交互在查询键交互附近进行聚类。ACT可以将自注意力的二次O(N2)复杂度降低到O(NK),其中K是每层原型的数量。ACT可以是嵌入式模块,无需任何训练即可代替原始的自注意力模块。


2、设计动机


2.1、编码器注意力冗余

DETR中,每个位置的特征将使用注意力机制自适应地从其他空间位置收集特征信息。而语义上相似和空间上相近的特征会产生相似的注意力图,反之亦然。如下图Figure 1:

P0和P1的注意力图相似且包含冗余,而距离较远的P0和P3的注意图则完全不同。自注意力的冗余促使行为选择具有代表性的原型,并将原型特征更新传播给最邻近的位置。

2.2、编码器特征多样性

随着编码器的深入,各个特性将会相似,因为每个特性会互相收集信息。为了验证这个假设,文中计算了训练集中前100张照片中每一层特征到中心的平均距离。

如Figure 2所示,随着Layer的深入,特征距离会降低。这种观察得出特征会在每一层之间的分布自适应地确定原型的数量,而不是静态的聚类中心的数量。


3、论文方法


作者为了将类相似的查询特征组合在一起进而解决编码器注意力冗余问题,最开始想法是使用K-means对所有图像使用预定义的中心数量来聚类查询特性。实验发现K-means聚类会显著降低预先训练好的DETR的性能。同时编码特征多样性促使设计一种自适应聚类算法,可以根据每幅图像和每一层的特征分布对特征进行聚类。然后作者选择了一种Multi-round Exact Euclidean Locality Sensitivity Hashing (E2LSH)算法,该算法能够实现查询特征感知分布的聚类。

3.1、DETR主要框架

下图显示DETR的3个阶段:

3.1.1 编码阶段

在编码器中,使用imagenet预训练的ResNet模型从输入图像中提取二维特征。位置编码模块使用不同频率的正弦和余弦函数对空间信息进行编码。DETR将2D特性扁平化,并用位置编码补充它们,并将它们传递给6层transformer编码器。编码器的每一层结构相同,包括8头自注意模块和FFN模块。

3.1.2 解码阶段

解码器将一小部分固定数量并且已学习好的位置嵌入(称为对象查询)作为输入,并另外处理编码器的输出。该解码器也有6层,每层具有相同的配置:包括8-head自注意力模块、8-head co-attention模块和1个FFN模块。

3.1.3 预测阶段

DETR将解码器的每个输出传递给共享前馈网络,该网络预测检测(类和边界框)或无对象类。

3.2、Adaptive Clustering Transformer

3.2.1 确定哈希原型

因为LSH是一个解决最近邻搜索问题的不错的解决方案,所以使用Locality Sensitivity hash(LSH)自适应地聚合那些欧氏距离较小的查询。通过控制哈希函数参数让所有的向量距离小于的以一个大于p的概率落入相同的hash bucket。

image.png

文章选择精确欧几里得局部性敏感哈希(E2LSH)作为哈希函数:

image.png

其中是哈希函数,是超参数,是随机变量,且满足服从分布服从分布;本文则应用L轮LSH来增加结果的可信度:

image.png

下图Figure 4显示了哈希函数的原理。每个哈希函数可以看作是一组具有随机法向量和偏移量的并行超平面。超参数r控制超平面的间距。r越大,间隔越大。此外,L哈希函数将空间划分为多个单元格,落入同一个单元格的向量将获得相同的哈希值。显然,欧氏距离越近,两个向量落在同一个单元格中的概率就越大。

image.png

为了获得原型,首先计算每个查询的哈希值。然后,将具有相同哈希值的查询分组到一个集群中,该集群的原型是这些查询的中心。具体将定义为查询,定义为原型,其中是集群的数量。设表示所属的聚类指标。第个聚类的原型:

image.png

3.2.2 估计注意力输出

在原型确定后,每一组查询都由一个原型表示。因此,只需要计算原型和键之间的注意力映射。然后获得每个原型的目标向量,并将其广播到每个原始查询。于是便得到了对注意力输出的估计。与精确的注意量计算相比,减少了计算的复杂性到,其中C为原型数,其大于N,将自适应确定。

具体定义K为键,V为值。通过下面的方程得到了注意力输出的估计数:

image.png

其中,softmax函数逐行应用,表示所属的聚类指数。

3.2.3 误差控制

对于和两个向量,令。假设和落在同一个哈希桶中的概率是那么就可以证明:

image.png

其中,是均值为0,方差为1的正态分布的概率密度函数。

ACT独立应用L轮哈希,和 的碰撞概率为。显然,与c呈单调递减,因此,当两个查询的距离小于。对于给定的置信度可以通过调整超参数L和r来控制查询与其原型在同一簇内的距离,估计误差会随着L的增加和r的减少而减小。

3.2.4 多任务知识蒸馏

虽然ACT可以在不经过再训练的情况下降低DETR的计算复杂度,但多任务知识蒸馏(MTKD)可以通过更少的微调时间进一步提高ACT,并在性能和计算之间取得更好的平衡。

image.png

如下图所示图像特征将首先通过预训练好的CNN主干进行提取。提取的特征将通过ACT与Tranformer并联。为了在ACT和原始的transformer之间实现无缝切换,MTKD将用于约束训练。训练损失表示为:

image.png

式中,其中为Ground Truth, 为ACT预测的Bounding Box,, Y2为DETR预测的Bounding Box和完整预测。是Ground Truth与DETR预测之间的原始loss。是知识蒸馏的损失,它使ACT和DETR预测的Bounding Box之间的L2距离最小化。

训练损失的目的是结合完全预测和近似预测之间的知识转移对原始Transformer进行训练,实现ACT和transformer之间的无缝切换。知识转换器包括区域分类和回归蒸馏。回归分支比分类分支对由ACT引入的近似误差更敏感。因此,只转移Bounding Box回归分支的知识。实验观察到通过只Bounding Box回归分支的训练收敛速度要快得多。

相关文章
|
3月前
|
机器学习/深度学习 数据可视化 计算机视觉
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
这篇文章详细介绍了如何通过可视化深度学习中每层特征层来理解网络的内部运作,并使用ResNet系列网络作为例子,展示了如何在训练过程中加入代码来绘制和保存特征图。
70 1
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
|
4月前
|
机器学习/深度学习 搜索推荐
CIKM 2024:LLM蒸馏到GNN,性能提升6.2%!Emory提出大模型蒸馏到文本图
【9月更文挑战第17天】在CIKM 2024会议上,Emory大学的研究人员提出了一种创新框架,将大型语言模型(LLM)的知识蒸馏到图神经网络(GNN)中,以克服文本图(TAGs)学习中的数据稀缺问题。该方法通过LLM生成文本推理,并训练解释器模型理解这些推理,再用学生模型模仿此过程。实验显示,在四个数据集上性能平均提升了6.2%,但依赖于LLM的质量和高性能。论文链接:https://arxiv.org/pdf/2402.12022
103 7
|
8月前
|
机器学习/深度学习 人工智能 计算机视觉
CVPR 2023 | AdaAD: 通过自适应对抗蒸馏提高轻量级模型的鲁棒性
CVPR 2023 | AdaAD: 通过自适应对抗蒸馏提高轻量级模型的鲁棒性
264 0
|
8月前
|
机器学习/深度学习 自然语言处理 算法
从滑动窗口到YOLO、Transformer:目标检测的技术革新
从滑动窗口到YOLO、Transformer:目标检测的技术革新
199 0
|
数据可视化 数据挖掘
即插即用 | 通过自适应聚类Transformer来提升DERT目标检测器的速度(文末附论文下载)(二)
即插即用 | 通过自适应聚类Transformer来提升DERT目标检测器的速度(文末附论文下载)(二)
257 0
|
数据可视化 Go 计算机视觉
FastPillars实时3D目标检测 | 完美融合PointPillar、YOLO以及RepVGG的思想(二)
FastPillars实时3D目标检测 | 完美融合PointPillar、YOLO以及RepVGG的思想(二)
231 0
|
机器学习/深度学习 存储 自动驾驶
FastPillars实时3D目标检测 | 完美融合PointPillar、YOLO以及RepVGG的思想(一)
FastPillars实时3D目标检测 | 完美融合PointPillar、YOLO以及RepVGG的思想(一)
1468 0
|
机器学习/深度学习 计算机视觉 索引
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(一)
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(一)
579 0
|
计算机视觉
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(二)
目标检测无痛涨点新方法 | DRKD蒸馏让ResNet18拥有ResNet50的精度(二)
154 0
|
机器学习/深度学习 自动驾驶 计算机视觉
目标检测提升技巧 | 结构化蒸馏一行代码让目标检测轻松无痛涨点(一)
目标检测提升技巧 | 结构化蒸馏一行代码让目标检测轻松无痛涨点(一)
165 0