ICCV2021 | Swin Transformer: 使用移位窗口的分层视觉Transformer

简介: 本文解读的论文是ICCV2021中的最佳论文,在短短几个月内,google scholar上有388引用次数,github上有6.1k star。

Motivation


论文试图扩展Transformer的适用性,使其可以作为计算机视觉的通用主干,就像它在NLP中所做的那样,也可以像CNNs在视觉中所做的那样。


论文提到,将其在语言领域的高性能转换到视觉领域的重大挑战可以用这两种模式之间的差异来解释。这些不同之处之一涉及到规模。


与作为语言transformer中处理的基本元素的单词tokens不同,视觉元素在尺度上可以有很大的变化,这是一个在诸如目标检测之类的任务中受到关注的问题。在现有的基于transformer的模型中,tokens都是固定比例的,这一特性不适合这些视觉应用


另一个不同之处在于,与文本段落中的文字相比,图像中像素的分辨率要高得多。存在许多视觉任务,如语义分割,需要在像素级别进行密集预测,这对于高分辨率图像上的Transformer来说是很困难的,因为它的self-attention的计算复杂度是图像大小的二次方

 

创新思路


为了克服这些问题,论文提出了一种通用的Transformer骨干网,称为Swin Transformer,它构造了分层的特征映射,并且计算复杂度与图像大小成线性关系。

c1847f9a99e2b6829687c8c5023fbc85.png

如图1(A)所示,Swin Transformer通过从小块(灰色轮廓)开始,逐渐合并更深的Transformer层中的相邻块来构建分层表示


有了这些分层的特征图,Swin Transformer模型可以方便地利用先进的技术进行密集预测,如特征金字塔网络(FPN)或U-Net。线性计算复杂度是通过在分割图像(红色轮廓)的非重叠窗口内局部计算self-attention来实现的。每个窗口中的patches数量是固定的,因此复杂度与图像大小成线性关系


这些优点使得Swin Transformer适合作为各种视觉任务的通用主干,而不是以前基于Transformer的架构,后者生成单一分辨率的特征地图,并且具有二次方复杂性。

1f8a211b78a73f10e2b83ba5aa0b1150.png

Swin Transformer的一个关键设计元素是窗口分区在连续的self-attention层之间的移动,如图2所示。移动的窗口桥接了前一层的窗口,提供了它们之间的连接,显著增强了建模能力


这种策略在实际延迟方面也是有效的:一个窗口内的所有query patch都共享相同的key集,这便于硬件中的内存访问。相反,较早的基于滑动窗口的self-attention方法由于不同query像素的不同key集而在一般硬件上受到低延迟的影响。


实验表明,所提出的移位窗口方法比滑动窗口方法具有更低的延迟,但在建模能力上是相似的。事实证明,移位窗口方法对于全MLP体系结构也是有益的。


Methods


Overall Architecture


Swin Transformer架构的概述如图3所示,它展示了tiny版本(Swin-T)。


eafc05bb03e664a2ba574e3ba29d0c9b.png

图3.(a)Swin Transformer(Swin-T)的架构;(b)两个连续的Swin Transformer块(用公式表示(3))。W-MSA和SW-MSA分别是具有规则和移位窗口配置的多头自注意模块。


它首先通过patch分割模块(如ViT)将输入的RGB图像分割成不重叠的patch。每个patch都被视为一个“token”,其特征被设置为原始像素RGB值的串联。在实现中,论文使用了4×4的块大小,因此每个块的特征维度是4×4×3=48。将线性嵌入层应用于该原始值特征以将其投影到任意维度(表示为C)。


在这些patch tokens上应用了几个带有修改的self-attention计算的transformer block (Swin Transformer block)。transformer块保持tokens数(H/4×W/4),与线性嵌入一起称为“Stage1”。


为了产生分层表示,随着网络的深入,通过patch合并层来减少tokens的数量。第一个patch合并层将每组2×2相邻patch的特征进行拼接,并在4C维拼接的特征上应用线性层。这将tokens的数量减少了2×2=4的倍数(2倍下采样),并且输出维度被设置为2C。然后应用Swin Transformer块进行特征变换,分辨率保持为H/8×W/8。这第一个块的拼接和特征变换称为“Stage2”。重复“Stage3”和“Stage4”两次,输出分辨率分别为H/16×W/16和H/32×W/32。


这些Stage共同产生具有与典型卷积网络(如VGG和ResNet)相同的特征映射分辨率的分层表示。因此,该体系结构可以方便地取代现有方法中的骨干网络,用于各种视觉任务。

 欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读


Swin Transformer Block: Swin Transformer通过将transformer块中的标准多头self-attention(MSA)模块替换为基于移位窗口的模块,在保持其他层不变的情况下构建Swin Transformer。

 1daf33c0574979e91286d6a72c1d605a.png

如图3(b)所示,Swin Transformer模块由一个基于移位窗口的MSA模块和一个中间带有GELU非线性的两层MLP组成。在每个MSA模块和每个MLP之前应用LayerNorm(LN)层,并且在每个模块之后应用残差连接。

 

基于移位窗口的self-attention


非重叠窗口中的self-attention:  为有效建模,论文提出在局部窗口中计算self-attention。窗口被布置成以不重叠的方式均匀地分割图像。假设每个窗口包含M×M个patch,全局MSA模块和基于h×w patch图像的窗口的计算复杂度分别为

0c8aabf246e58aba3f7967599d4bdc22.png0c8aabf246e58aba3f7967599d4bdc22.png

其中,前者与patch数HW为平方关系,后者在M固定时是线性的(缺省情况下设置为7)。全局self-attention计算对于大型硬件来说通常是负担不起的,而基于窗口的self-attention是可伸缩的。

 


在连续块中移动窗口分区:  基于窗口的self-attention模块缺少跨窗口的连接,这限制了其建模能力。为了在保持非重叠窗口计算效率的同时引入跨窗口连接,论文提出了一种移位窗口划分方法,该方法在连续Swin Transformer块中的两种划分配置之间交替。

1f8a211b78a73f10e2b83ba5aa0b1150.png

在Swin Transformer架构中计算self-attentioin的移位窗口方法的图示。在Layer1(左)中,采用规则的窗口划分方案,并在每个窗口内计算自我关注。在下一层l+1(右)中,窗口分区被移位,从而产生新窗口。新窗口中的self-attention计算跨越了层l中先前窗口的边界,提供了它们之间的连接。


如图所示,第一个模块使用从左上角像素开始的规则窗口划分策略,将8×8特征图均匀划分为大小为4×4(M=4)的2×2个窗口。然后,下一模块通过将窗口从规则划分的窗口移位(M/2,M/2)(向下取整)像素来采用从前一层的窗口移位的窗口配置。使用移位窗口分区方法,连续的Swin Transformer块计算为

440375af03e96a3c96ca36e126b27733.png

其中,ˆzl和zl分别表示块1的(S)WMSA模块和MLP模块的输出特征;W-MSA和SW-MSA分别表示使用规则和移位窗口分区配置的基于窗口的多头self-attention。

移位窗口划分方法引入了前一层相邻非重叠窗口之间的连接,在图像分类、目标检测和语义分割中被发现是有效的。

 


移位的高效批处理计算:移位窗口分区的一个问题是,它将在移位中产生更多窗口,从h/M x w/M(向上取整)到(h/M + 1) x (w/M+1)(向上取整),并且一些窗口将比MxM更小。一个原始的解决方案是将较小的窗口填充到M×M的大小,并在计算注意力时屏蔽填充的值。当规则分区中的窗口数量较小时,例如2×2,使用这种朴素的解决方案增加的计算量是相当可观的(2×2→3×3,是2.25倍)。


95cefc01e027d3158aede683bc13f3f3.png

在这里,论文提出了更有效的批处理计算方法,即向左上角方向循环移动,如图所示。在这种转移之后,批处理窗口可能由特征图中不相邻的几个子窗口组成,因此采用mask机制将self-attention计算限制在每个子窗口内。在循环移位的情况下,批处理窗口的数量与常规窗口划分的数量相同,因此也是有效的。

 


相对位置偏差:在计算self-attention时,在计算相似度时将每个头部的相对位置偏差B(大小为M^2×M^2)包括在内:

c3849f7d7f3192e49a18448e40f1c814.png

其中Q,K,V大小为M^2 x d;d的大小为query/key,M^2是一个窗口中的patches数量。由于沿每个轴的相对位置在[−M+1,M−1]范围内,将较小尺寸的偏置矩阵ˆB(大小为(2M−1)×(2M−1))参数化,并且B中的值取自ˆB。

14a2823b37895b5ec128aca82c0dcacb.png

如表所示,论文提到,与没有这种bias项或使用绝对位置嵌入的同行相比,有显著的改进。进一步向输入添加绝对位置嵌入会略微降低性能,因此在论文的实现中不采用它。在预训练中学习到的相对位置偏差还可以用于通过双三次插值来初始化具有不同窗口大小的微调模型。

 

Architecture Variants


论文构建了名为Swin-B的基本模型,其模型大小和计算复杂度与ViTB/Deit-B相似。还提出了Swin-T、Swin-S和Swin-L,它们的模型规模和计算复杂度分别约为0.25×、0.5×和2倍。请注意,Swin-T和Swin-S的复杂度分别与ResNet-50(Deit-S)和ResNet-101相似。默认情况下,窗口大小设置为M=7。对于所有实验,每个头的query维度为D=32,每个MLP的扩展层为α=4。这些模型变体的体系结构超参数包括:

00dcbdabff6a68d5fe18199f41dc6243.png

Conclusion


论文提出的Swin Transformer在图像分类、目标检测和语义分割等识别任务中取得了较好的性能。它在三个任务上的延迟与Vit/Deit和ResNe(X)t模型相比要高得多。

1. 不同骨干网在ImageNet-1K分类上的比较。

894179946e3c4543f99269926216a072.png

2. 其在COCO测试开发集上的58.7box AP和51.1mask AP超过了之前SOTA结果+2.7box AP(无外部数据的复制-粘贴)和+2.6mask AP(DetectoRS)。


3fa15a946520b7437bff7e7bcbdfef8e.png

3.在ADE20K语义分割上,它在Val集合上获得了53.5mIoU,比之前的SOTA(SETR])提高了+3.2mIoU。在ImageNet-1K图像分类上达到了87.3%的TOP-1正确率。

8359bfed37eea9b2d67f9a7e0e7477b4.png

4. 不同的self-attention计算方法和实现在V100 GPU上的真实速度。

7a4165178bbea07d31524d26fc90ff0c.png

相关文章
|
机器学习/深度学习 存储 编解码
高效神经网络架构的正确打开方式! | EMO:结合 CNN 和 Transformer
高效神经网络架构的正确打开方式! | EMO:结合 CNN 和 Transformer
1220 0
|
5月前
|
机器学习/深度学习 自然语言处理 并行计算
【YOLOv8改进 -注意力机制】Mamba之MLLAttention :基于Mamba和线性注意力Transformer的模型
YOLOv8专栏探讨了该目标检测模型的创新改进,包括使用Mamba模型的线性注意力Transformer变体,称为MLLA。Mamba的成功关键在于遗忘门和块设计,MLLA结合了这些优点,提升了视觉任务的性能。文章提供全面分析,并提出MLLA模型,其在效率和准确性上超过多种视觉模型。论文和代码可在提供的链接中找到。MLLA Block的代码示例展示了如何整合关键组件以实现高效运算。更多配置详情见相关链接。
|
7月前
|
机器学习/深度学习 编解码 定位技术
【论文速递】ECCV2022 - 开销聚合与四维卷积Swin Transformer_小样本分割
【论文速递】ECCV2022 - 开销聚合与四维卷积Swin Transformer_小样本分割
|
机器学习/深度学习 自然语言处理 算法
从Transformer到ViT:多模态编码器算法原理解析与实现
从Transformer到ViT:多模态编码器算法原理解析与实现
638 0
|
机器学习/深度学习 编解码 文字识别
语义分割新SOTA | 当UNet与HRNet碰撞会产生怎样的火花?U-HRNet不做选择!!!
语义分割新SOTA | 当UNet与HRNet碰撞会产生怎样的火花?U-HRNet不做选择!!!
293 0
|
机器学习/深度学习 编解码 计算机视觉
Transformer新SOTA | 超越SWin、CSWin,MAFormer再探ViT Backbone新高度
Transformer新SOTA | 超越SWin、CSWin,MAFormer再探ViT Backbone新高度
274 0
|
机器学习/深度学习 编解码 PyTorch
金字塔ViT | 华为提出使用金字塔结构改进Transformer,涨点明显(Pytorch逐行解读)
金字塔ViT | 华为提出使用金字塔结构改进Transformer,涨点明显(Pytorch逐行解读)
302 0
|
机器学习/深度学习 存储 编解码
详细解读PVT-v2 | 教你如何提升金字塔Transformer的性能?(附论文下载)(一)
详细解读PVT-v2 | 教你如何提升金字塔Transformer的性能?(附论文下载)(一)
608 0
|
编解码 数据挖掘 计算机视觉
详细解读PVT-v2 | 教你如何提升金字塔Transformer的性能?(附论文下载)(二)
详细解读PVT-v2 | 教你如何提升金字塔Transformer的性能?(附论文下载)(二)
469 0
|
机器学习/深度学习 PyTorch 算法框架/工具
即插即用 | 5行代码实现NAM注意力机制让ResNet、MobileNet轻松涨点(超越CBAM)
即插即用 | 5行代码实现NAM注意力机制让ResNet、MobileNet轻松涨点(超越CBAM)
436 0