沈春华团队最新 | SegViT v2对SegViT进行全面升级,让基于ViT的分割模型更轻更强

本文涉及的产品
简介: 沈春华团队最新 | SegViT v2对SegViT进行全面升级,让基于ViT的分割模型更轻更强

作者探索了使用编码器-解码器框架的普通Vision Transformer(ViTs)进行语义分割的能力,并介绍了SegViTv2。在我们的工作中,我们实现了具有ViT Backbone 中固有的全局注意力力机制的解码器,并提出了轻量级的Attention-to-Mask(ATM)模块,该模块可以有效地将全局注意力力图转换为语义Mask,以获得高质量的分割结果。

本文的解码器在各种ViT Backbone 中的性能优于最常用的解码器UpperNet,同时仅消耗约5%的计算成本。对于编码器,解决了基于ViT的编码器中相对较高的计算成本的问题,并提出了一种Shrunk++结构,该结构结合了边缘感知的基于Query的下采样(EQD)和基于Query的上采样(QU)模块。Shrunk++结构将编码器的计算成本降低了50%,同时保持了有竞争力的性能。

此外,由于基于ViT架构的灵活性,SegViT可以在持续学习的情况下轻松扩展到语义分割,实现几乎零遗忘。实验表明提出的SegViT在ADE20k、COCO-Stuff-10k和PASCAL Context数据集等3个流行的基准上优于最近的分割方法。

代码:https://github.com/zbwxp/SegVit

1、简介

语义分割是计算机Vision 中的一项关键任务,需要对输入图像进行精确的像素级分类。在最先进的技术中广泛使用的传统方法,如全卷积网络(FCN),使用深度卷积神经网络(ConvNet)作为编码器或基本模型和分割解码器来生成密集预测。先前的工作旨在通过增强上下文信息或结合多尺度信息来提高性能,利用ConvNet架构固有的多尺度和层次属性。

Vision Transformer(ViT)的出现提供了一种范式转变,成为许多计算机Vision 任务的强大支柱。与ConvNet基本模型不同,ViT在保留特征图分辨率的同时,保留了简单和非层次的架构。为了方便地利用现有的分割解码器进行密集预测,如U-Net或DeepLab,最近的基于Transformer的方法,包括Swin-Transformer和PVT,已经开发了分层ViT来提取分层特征表示。然而,由于分层架构和普通架构之间的差异,如空间下采样,修改原始ViT结构需要从头开始训练网络,而不是使用现成的普通ViT检查点。

改变简单的ViT结构消除了利用从Vision 语言预训练方法(如CLIP、BEiT、BEiT-v2、MVP和COTS)获得的丰富表示的潜力。因此,为原始ViT结构开发有效的解码器以利用这些强大的表示具有明显的优势。先前的工作,如UPerNet和DPT,主要注意力层次特征图,而忽略了普通Vision Transformer的独特特征。

因此,如图1所示,这些方法导致计算要求很高的操作,性能改进有限。一些作品(如SETR或Segmenter)的最新趋势旨在开发专门为Plain ViT架构量身定制的解码器。然而,这些设计通常代表了从传统的基于卷积的解码器导出的每像素分类技术的简单扩展。例如,SETR的解码器使用卷积序列和双线性上采样来逐渐增加ViT提取的特征图。然后,它将朴素的MLP应用于提取的特征,以执行逐像素分类,从而隔离像素周围的相邻上下文。当前的按像素分类解码器设计在为每个像素分配标签时忽略了上下文学习的重要性。

包括Transformer在内的深度网络中的另一个普遍问题是“灾难性遗忘”,即模型在先前学习的任务上的性能随着学习新任务而恶化。这一限制对深度分割模型在动态现实世界环境中的应用提出了重大挑战。最近,在大规模数据上预训练的基础模型的快速发展激发了研究人员对研究其在各种下游任务中的可转移性的兴趣。这些模型能够提取强大的广义表示,这导致人们对探索其对新类和任务的可扩展性越来越感兴趣,同时保留先前学习的知识表示。

受这些挑战的激励,本文旨在探索如何在不需要分层 Backbone 重新设计的情况下,普通Vision Transformer能够更有效地执行语义分割任务。随着自监督和多模态预训练的不断发展,预计普通Vision Transformer将学习增强的Vision 表示。因此,用于密集任务的解码器被期望更灵活和有效地适应这些表示。

鉴于这些研究空白,本文提出了SegViTv2——一种新颖、高效的分割网络,它具有简单的Vision Transformer,并表现出对遗忘的鲁棒性。介绍了一种新颖的注意力力Mask(ATM)模块,它作为SegViT解码器的轻量级组件运行。利用交叉注意力力学习的非线性,提出的ATM使用可学习的类Token作为Query,以精确定位与每个类具有高度兼容性的空间位置。主张隶属于特定类别的区域具有与相应类别Token相对应的实质相似性 Value 。

如图2所示,ATM生成了一个有意义的相似性图,该图强调了对“桌子”和“椅子”类别具有强烈亲和力的区域。通过简单地实现Sigmoid运算,可以将这些相似性映射转换为Mask级别的预测。Mask的计算与像素数量成线性比例,这是一个可以忽略不计的成本,可以集成到任何 Backbone 中以提高分割精度。在这个高效的ATM模块的基础上,作者还提出了一种新的语义分割范式,该范式利用了普通ViT的成本效益结构,称为SegViT。在这个范例中,多个ATM模块被部署在不同的层,以提取不同规模的分割Mask。最后的预测是从这些层导出的输出的总和。

为了减轻普通Vision Transformer(ViTs)的计算负担,引入了“Shrunk”和“Shrunk++”结构,它们结合了基于Query的下采样(QD)和基于Query的上采样(QU)。所提出的QD采用2x2最近邻下采样技术来获得更稀疏的Token网格,从而减少了注意力力计算中涉及的Token数量。

此外,作者将QD扩展到基于边缘感知Query的下采样(EQD)。EQD选择性地保留位于模板边缘的Token,因为它们拥有更多的判别信息。因此,QU恢复了对象同质体内被丢弃的Token,重建了对精确密集预测至关重要的高分辨率特征。通过将“Shrunk”结构与作为解码器的ATM模块集成,实现了高达50%的计算减少,同时保持了有竞争力的性能水平。

作者进一步将SegViT框架的应用扩展到持续学习。利用基础模型所学的强大而广义的表示,本文旨在研究在不忘记所学知识的情况下将基础模型扩展到新类和新任务的能力。最近的连续语义分割(CSS)技术旨在重放旧数据或从以前的模型中提取知识,以缓解模型差异。这些方法需要对负责旧任务的参数进行微调,这可能会破坏先前学习的解决方案,导致遗忘。

相反,提出的SegViT支持在不侵犯先前获得的知识的情况下学习新课程。作者努力建立一个无遗忘的SegViT框架,该框架包含一个新的ATM模块,专门用于新任务,同时将所有旧参数保持在冻结状态。因此,所提出的SegViT有可能实际上消除遗忘问题。

主要贡献总结如下:

  • 介绍了注意力力到Mask(ATM)解码器模块,这是一种有效的语义分割工具。首次利用注意力力图中的空间信息为每个类别生成Mask预测,提出了一种新的语义分割范式。
  • 提出了适用于任何普通ViT Backbone 的Shrunk结构,它减轻了非分层ViT固有的高计算,同时保持了竞争性能,如图1所示是第一部利用边缘信息减少和恢复Token以实现高效计算的作品。Shrunk++版本的SegViT v2在ADE20K数据集上测试,实现了55.7%的mIoU,计算成本为308.8 GFLOP,与原始SegViT(637.9 GFLOP)相比减少了约50%。
  • 提出了一种新的SegViT架构,该架构能够在没有遗忘的情况下持续学习。据所知是第一个试图完全冻结旧类的所有参数的工作,从而几乎消除了灾难性遗忘的问题。

2、本文方法

在本节中将首先介绍提出的SegViT模型的总体架构。然后,将继续讨论“Shrunk”架构,该架构旨在降低模型的总体计算成本。此外,将深入研究连续语义分割设置,并调整SegViT模型框架以与此设置无缝一致。

2.1、Overall SegViT architecture

SegViT包括负责特征提取的基于ViT的编码器和用于学习分割图的解码器。在编码器方面,设计了“Shrunk”结构,以降低与普通ViT相关的计算成本。关于解码器介绍了一种新的轻量级模块,称为注意力Mask(ATM)。该模块生成表示为M的类特定Mask和表示为P的类预测,它们确定图像中特定类的存在。来自ATM模块堆栈的Mask输出被组合,然后与类预测相乘,以获得最终的分段输出。图3展示了提出的SegViT的总体架构。

1、Encoder

给定输入图像,Plain Transformer Backbone 将其reshape为Token序列,其中,P是patch-size,C是通道数。为了捕获位置信息,添加了与大小相同的可学习位置嵌入。随后,Token序列经历m个Transformer层以产生输出。每个层的输出Token定义为。在像ViT这样的普通Vision Transformer的情况下,不涉及额外的模块,并且每个层的Token数量保持不变。然而,普通ViT的计算成本可能高得令人望而却步。为了解决这个问题,引入了Shrunk结构,它能够开发一种高效的基于ViT的编码器。

2、Decoder

注意力Mask(ATM)。交叉注意力可以描述为两个Token序列之间的映射,表示为。在本文的情况下,定义了2个Token序列:,长度N等于类的数量,。为了实现交叉注意力,将线性Transformer应用于每个Token序列,从而产生Query(Q)、Key(K)和Value(V)表示。该过程由等式(1)描述。

相似度图是通过计算Query和Key表示之间的点积来计算的。

根据缩放点积注意力机制,相似度图和注意力图计算如下:

其中是比例因子,等于key的尺寸。

相似性映射的形状由2个Token序列N和L的长度确定。注意力机制通过执行V的加权和来更新G,其中,在沿着L维度应用softmax函数之后,从相似性映射导出权重。

在点积注意力力中,softmax函数用于将注意力集中在相似度最高的Token上。然而,除了那些具有最大相似性的Token之外,Token也携带有意义的信息。基于这种直觉设计了一个轻量级模块,可以更直接地生成语义预测。为了实现这一点,将G指定为分割任务的类嵌入,将指定为ViT Backbone 的第层的输出。语义Mask与G中的每个Token配对,以表示每个类的语义预测。二进制MaskM定义如下:

Mask的形状为,可以reshape为,并可二次上采样为原始图像大小。ATM机制如图3的右部分所示,在交叉注意力过程中产生Mask作为其中间输出。

ATM模块的最终输出Token用于分类。用参数化的全连通层(FC)和Softmax函数来预测对象类是否存在于图像中。类预测被形式定义为:

这里,表示类别c出现在图像中的可能性。为了简单起见,将称为类的概率分数。

类的输出分割图是通过reshape的类特定Mask 及其相应的预测分数的逐元素相乘获得的。在推理过程中,通过使用,选择得分最高的类,将标签分配给每个像素。

事实上,像ViT这样的普通基础模型并不固有地具有具有不同规模特征的多个阶段。因此,诸如合并来自多个尺度的特征的特征金字塔网络(FPN)之类的结构不适用于它们。

尽管如此,ViT中除最后一个层之外的层的特征包含有价 Value 的Low-level语义信息,这有助于提高性能。在SegViT中开发了一种结构,利用来自不同ViT层的特征图来丰富特征表示。这使我们能够合并这些特征图中存在的丰富的Low-level语义信息并从中受益。

SegViT是通过分类损失和二进制Mask损失来训练的。分类损失()使类别预测和实际目标之间的交叉熵最小化。Mask损失()由Focal Loss和Dice Loss组成,用于优化分割精度并解决Mask预测中的样本不平衡问题。Dice Loss和Focal Loss分别使预测Mask和GT分割之间的Dice和Focal得分最小化。最终损失是各项损失的组合,正式定义为:

其中λλ是控制每个损失函数强度的超参数。先前的Mask Transformer方法,如MaskFormer和DETR,已经采用了二进制Mask Loss,并通过经验实验微调了它们的超参数。

因此,为了一致性直接使用与MaskFormer和DETR相同的 Value 作为损失超参数:λλ

2.2、Shrunk Structure for Efficient Plain ViT Encoder

最近的跟踪,如DynamicViT、TokenLearner和SPViT,提出了Token修剪技术来加速Vision Transformer。然而,这些方法中的大多数是专门为图像分类任务设计的,因此会丢弃有价 Value 的信息。当这些技术应用于语义分割时,它们可能无法保留精确密集预测任务所需的高分辨率特征。

在本文中提出了Shrunk结构,该结构利用基于Query的下采样(QD)来修剪输入Token序列,并利用Query上采样(QU)来恢复丢弃的Token,从而保留对语义分割至关重要的精细细节特征。QD和QU的总体架构如图4所示。

对于QD,重新设计了Transformer编码器块,并引入了高效的下采样操作,以专门减少QueryToken的数量。在Transformer编码器层中,计算成本直接受Query Token数量的影响,输出大小由Query Token大小决定。为了在保持信息完整性的同时减轻计算负担,一种可行的策略是在保留Key和 Value Token的同时选择性地减少Query Token的数量。这种方法允许有效地减小当前层的输出大小,从而降低后续层的计算成本。

对于QU,通过使用具有更高分辨率的预定义或继承的Token序列作为QueryToken来实现上采样。Key和 Value Token取自从 Backbone 获得的Token序列, Backbone 通常具有较低的分辨率。输出大小由具有更高分辨率的QueryToken决定。通过交叉注意力机制,来自Key和 Value Token的信息被集成到输出中。该过程促进了信息的非线性合并,并展示了上采样行为,有效地提高了输出的分辨率。

如图5所示提出的Shrunk结构包含QD和QU模块。具体来说,在ViT Backbone 的中间深度集成了QD操作,正好在24层 Backbone 的第8层。QD操作使用2×2最近邻下采样操作对QueryToken进行下采样,从而将特征图大小减小到1/32。

然而,这种下采样可能会导致信息丢失和性能下降。为了缓解这个问题,在应用QD运算之前,对特征图采用了QU运算。这包括初始化一组分辨率为1/16的QueryToken来存储信息。随后,随着下采样的特征图在剩余的 Backbone 层中前进,使用另一个QU操作与先前存储的1/16高分辨率特征图一起对其进行合并和上采样。这个迭代过程最终生成了一个1/16的高分辨率特征图,该特征图富含 Backbone 处理的语义信息。

尽管所提出的Shrunk方法在保持性能方面是有效的,但它需要将QD操作集成到 Backbone 的中间层中。这种必要性是由于浅层主要捕获Low-level特征,并且对这些层应用下采样将导致显著的信息损失。因此,这些低层继续以更高的分辨率进行计算,限制了计算成本的潜在降低。

为了解决这一限制并进一步优化 Backbone ,作者引入了增强功能,并提出了一种称为Shrunk++的新架构。在这种增强的体系结构中,在QD部分引入了边缘检测模块,并引入了边缘Query下采样(EQD)技术来更新QD过程。除了消除每4个连续Token的2×2最近下采样操作外,本文的方法旨在保留包含多个类别的Token,特别是包含边的Token。通过保留2×2稀疏Token,保留了重要的语义信息,同时也保留了边缘Token以保留详细的空间信息。

通过保留这两种类型的信息,最大限度地减少了有价值信息的损失,并克服了与Low-level别层相关的限制。为了提取边缘,使用轻量级多层感知器(MLP)边缘检测头添加了一个单独的分支,该分支学习从输入图像中检测边缘。边缘检测头作为辅助分支工作,与ATM解码器同时训练。该Head处理输入图像,该图像具有与 Backbone 相同的尺寸。让输入图像具有与 Backbone 对齐的C通道。该Head中的多层感知器(MLP)由三层组成,尺寸分别为C、C/2和2。

设表示输入图像,并且MLP的输出可以定义为,其中、、是三层的权重。然后,输出通过激活函数,得到。为了确定属于边缘的Token的置信水平,应用阈 Value τ。在实现中,将τ设置为0.7。为了获得GT边缘,对GT分割图Y进行后处理。由于输入已经用patch大小P进行了Token化,对GT进行Token化,并将其reshape为Token序列,表示为,其中最后两个维度对应于patch维度。如果patch中存在任何边缘像素,认为patch包含边缘。将边缘Mask 定义如下:

对于S中的每个元素,创建一个二进制边缘Mask,如果τ。交叉熵损失是在生成的边缘Mask 和GT边缘Mask 之间计算的:。通过将边缘检测头作为辅助分支,Shrunk++架构在整个Query下采样过程中有效地保留了详细的空间上下文,形成了边缘Query下采样(EQD)结构。

这种EQD结构允许模型在备用下采样期间捕获和保存边缘信息,从而在不影响性能的情况下显著减少计算开销。EQD的集成使Shrunk++架构能够在计算效率和保持高性能水平之间取得显著的平衡。

2.3、连续语义分割的探索

连续语义分割的目的是在不忘记的情况下以T步训练分割模型。在步骤t,得到了一个数据集,它包括一组,其中是大小为的图像,是GT分割图。这里,仅由当前类中的标签组成,而所有其他类(即旧类或未来类)都被分配给背景。在持续学习中,步骤t的模型应该能够预测历史上的所有类别。

SegViT用于持续学习。现有的连续语义分割方法提出了正则化算法来保存特定架构DeepLabV3的过去知识。这些方法侧重于使用ResNet Backbone 的DeepLabV3的连续语义分割,该 Backbone 在区分不同类别时具有较差的鲁棒性Vision 表示。因此,这些方法需要微调模型参数来学习新类,同时试图保留旧类的知识。

不幸的是,调整用于前一项任务的旧参数不可避免地会干扰过去的知识,导致灾难性的遗忘。相反,本文的方法SegViT将类预测与Mask分割解耦,使其本质上适合于连续学习环境。通过利用普通Vision Transformer的强大表示能力,可以通过单独微调类代理(即类Token)来学习新的类,同时保持旧参数的冻结。这种方法消除了在学习新任务时微调旧参数的需要,有效地解决了灾难性遗忘的问题。

当在当前任务t上训练时,添加了一个新的可学习Token序列,其长度等于当前任务中的类。为了学习新的类,生长和训练新的ATM模块和用于Mask预测和Mask分类的全连接层。为了简单起见,忽略了ATM模块的并行结构。单个ATM模块是指多个ATM模块。设和表示任务的ATM模块和完全连接(FC)层的权重。先前任务的所有参数,包括ViT编码器、ATM模块和FC层,都被完全冻结。

图6展示了适用于连续语义分割的SegViT架构的概述。

给定编码器提取的特征和类Token ,ATM产生对应于Mask的Mask预测和输出Token :

基于等式4,通过在类Token 上应用FC来获得类预测P。将每个类别的预测得分与Mask 相乘以获得类别的分割图:

其中表示按元素相乘。通过取在每个像素中具有最高分数的类来获得分割的,定义为

基于任务的GT,使用等式5中定义的损失函数来训练SegViT。为了生成所有任务的最终分割,将所有任务的输出连接起来。

3、实验

3.1、SOTA对比

1、ADE20K

2、COCO-Stuff-10K

3、PASCAL-Context

3.2、消融实验

1、Effect of the ATM module

2、Ablation of the feature levels

3、SegViT on hierarchical base models

4、Ablation of Shrunk and Shrunk++ strategies

5、Ablation of the components in Shrunk structure

6、Ablation studies on decoder variances

7、Ablation for the QD module

3.3、应用1:一个更好的特征表示学习指标

3.4、应用2:连续的语义分割

3.5、结果于讨论

4、参考

[1].SegViTv2: Exploring Efficient and Continual Semantic Segmentation with Plain Vision Transformers.

相关实践学习
基于函数计算一键部署掌上游戏机
本场景介绍如何使用阿里云计算服务命令快速搭建一个掌上游戏机。
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
相关文章
|
19天前
|
机器学习/深度学习
YOLOv8改进 | 细节创新篇 | iAFF迭代注意力特征融合助力多目标细节涨点
YOLOv8改进 | 细节创新篇 | iAFF迭代注意力特征融合助力多目标细节涨点
162 0
|
15天前
|
存储 机器学习/深度学习 人工智能
论文介绍:InfLLM——揭示大型语言模型在无需训练的情况下处理极长序列的内在能力
【5月更文挑战第18天】InfLLM是一种新方法,无需额外训练即可增强大型语言模型处理极长序列的能力。通过使用记忆单元存储长序列的远距离上下文,InfLLM能更准确地捕捉长距离依赖,提高对长文本理解。实验表明,InfLLM使预训练在短序列上的模型在处理极长序列时表现媲美甚至超过专门训练的模型。尽管有挑战,如动态上下文分割和记忆单元效率,InfLLM为长序列处理提供了有效且未经训练的解决方案。论文链接:https://arxiv.org/abs/2402.04617
32 3
|
19天前
|
机器学习/深度学习 数据挖掘 测试技术
DETR即插即用 | RefineBox进一步细化DETR家族的检测框,无痛涨点
DETR即插即用 | RefineBox进一步细化DETR家族的检测框,无痛涨点
139 1
|
19天前
|
机器学习/深度学习
YOLOv5改进 | 细节创新篇 | iAFF迭代注意力特征融合助力多目标细节涨点
YOLOv5改进 | 细节创新篇 | iAFF迭代注意力特征融合助力多目标细节涨点
92 0
|
10月前
|
人工智能 缓存 并行计算
终极「揭秘」:GPT-4模型架构、训练成本、数据集信息都被扒出来了
终极「揭秘」:GPT-4模型架构、训练成本、数据集信息都被扒出来了
539 0
|
11月前
|
机器学习/深度学习
结合亲和力提高了 28.7 倍,基于端到端贝叶斯语言模型的方法设计大型、多样化的高亲和力抗体库
结合亲和力提高了 28.7 倍,基于端到端贝叶斯语言模型的方法设计大型、多样化的高亲和力抗体库
|
机器学习/深度学习 人工智能 算法
模型部署系列 | 一文告诉你AI模型QAT量化遇到震荡问题应该如何解决呢?(二)
模型部署系列 | 一文告诉你AI模型QAT量化遇到震荡问题应该如何解决呢?(二)
197 0
|
机器学习/深度学习 人工智能 算法
模型部署系列 | 一文告诉你AI模型QAT量化遇到震荡问题应该如何解决呢?(一)
模型部署系列 | 一文告诉你AI模型QAT量化遇到震荡问题应该如何解决呢?(一)
514 0
|
机器学习/深度学习 人工智能 算法
微软提出自动化神经网络训练剪枝框架OTO,一站式获得高性能轻量化模型
微软提出自动化神经网络训练剪枝框架OTO,一站式获得高性能轻量化模型
186 0
|
人工智能 自然语言处理 安全
表现优于 GPT-4,ChemCrow 集成 13 种化学工具,增强大型语言模型的化学性能
表现优于 GPT-4,ChemCrow 集成 13 种化学工具,增强大型语言模型的化学性能
254 0