3.1.2.2 位置编码模块
位置编码模块是为 Transformer 模块提供 Patch 和 Patch 之间的相对关系,非常关键。在通用任务的 Transformer 模型中认为一个好的位置编码应该要满足以下特性:
- 不同位置的位置编码向量应该是唯一的
- 不能因为不同位置位置编码的值大小导致网络学习有倾向性
- 必须是确定性的
- 最好能够泛化到任意长度的序列输入
ViT 位置编码模块满足前 3 条特性,但是最后一条不满足,当输入图片改变时候需要微调,比较麻烦。基于此也出现了不少的算法改进,结构图如下所示:
按照是否显式的设置位置编码向量,可以分成:
- 显式位置编码,其中可以分成绝对位置编码和相对位置编码。
- 隐式位置编码,即不再直接设置绝对和相对位置编码,而是基于图片语义利用模型自动生成能够区分位置信息的编码向量。
隐式位置编码对于图片长度改变场景更加有效,因为其是自适应图片语义而生成。
3.1.2.2.1 显式位置编码
显式位置编码,可以分成绝对位置编码和相对位置编码。
(1) 绝对位置编码
绝对位置编码表示在 Patch 的每个位置都加上一个不同的编码向量,其又可以分成固定位置编码即无需学习直接基于特定规则生成,常用的是 Attention is all you need 中采用的 sincos 编码,这种编码方式可以支持任意长度序列输入。还有一种是可学习绝对位置编码,即初始化设置为全 0 可学习参数,然后加到序列上一起通过网络训练学习,典型的例如 ViT、PVT 等等。
(2) 相对位置编码
相对位置编码考虑为相邻 Patch 位置编码,其实现一般是设置为可学习,例如 Swin Transformer 中采用的可学习相对位置编码,其做法是在 QK 矩阵计完相似度后,引入一个额外的可学习 bias 矩阵,其公式为:
Swin Transformer 这种做法依然无法解决图片尺寸改变时候对相对位置编码插值带来的性能下降的问题。在 Swin Transformer v2 中作者做了相关实验,在直接使用了在 256 * 256 分辨率大小,8 * 8 windows 大小下训练好的 Swin-Transformer 模型权重,载入到不同尺度的大模型下,在不同数据集上进行了测试,性能如下所示 (Parameterized position bias 这行):
每个表格中的两列表示没有 fintune 和有 fintune,可以看出如果直接对相对位置编码插值而不进行 fintune,性能下降比较严重。故在 Swin Transformer v2 中引入了对数空间连续相对位置编码 log-spaced continuous position bias,其主要目的是可以更有效地从低分辨权重迁移到高分辨率下游任务。
相比于直接应用可学习的相对位置编码,v2 中先引入了 Continuous relative position bias (CPB),
B 矩阵来自一个小型的网络,用来预测相对位置,该模块的输入依然是 Patch 间的相对位置,这个小型网络可以是一个 2 层 MLP,然后接中间采用激活函数连接。
其性能如上表的 Linear-Spaced CPB 所示,可以发现相比原先的相对位置编码性能有所提升,但是当模型尺度继续增加,图片尺寸继续扩大后性能依然会下降比较多,原因是预测目标空间是一个线性的空间。当 Windows 尺寸增大的时候,比如当载入的是 8*8 大小 windows 下训练好的模型权重,要在 16*16 大小的 windows 下进行 fine-tune,此时预测相对位置范围就会从 [-7,7] 增大到 [-15,15],整个预测范围的扩大了不少,这可能会出现网络不适应性。为此作者将预测的相对位置坐标从 linear space 改进到 log space 下,这样扩大范围就缩小了不少, 可以提供更加平滑的预测范围,这会增加稳定性,提升泛化能力,性能表如上的 Log-Spaced CPB 所示。
在 Swin Transformer 中相对位置编码矩阵 shape 和 QK矩阵计算后的矩阵一样大,其计算复杂度是 O(HW),当图片很大或者再引入 T 时间轴,那么计算量非常大, 故在 Improved MViT ,作者进行了分解设计,分成 H 轴相对位置编码,W 轴相对位置编码,然后相加,从而将复杂度降低为 O(H+W)。
关于绝对位置编码和相对位置编码到底哪个是最好的,目前还没有定论,在不同的论文实验中有不同的结论,暂时来看难分胜负。但是从上面可以分析来看,在 ViT 中不管是绝对位置编码和相对位置编码,当图片大小改变时候都需要对编码向量进行插值,性能都有不同程度的下降 ( Swin Transformer v2 在一定程度上解决了)。
3.1.2.2.2 隐式位置编码
当图片尺寸改变时候,隐式位置编码可以很好地避免显式位置编码需要对对编码向量进行插值的弊端。其核心做法都是基于图片语义自适应生成位置编码。
在论文 How much position information do convolutional neural networks encode? 中已经证明 CNN 不仅可以编码位置信息,而且越深的层所包含的位置信息越多,而位置信息是通过 zero-padding 透露的。既然 Conv 自带位置信息,那么可以利用这个特性来隐式的编码位置向量。大部分算法都直接借鉴了这一结论来增强位置编码,典型代表有 CPVT、PVTv2 和 CSwin Transformer 等。
CPVT 指出基于之前 CNN 分类经验,分类网络通常需要平移不变性,但是绝对位置编码会在一定程度打破这个特性,因为每个位置都会加上一个独一无二的位置编码。看起来似乎相对位置编码可以避免这个问题,其天然就可以适应不同长度输入,但是由于相对位置编码在图像分类任务中无法提供任何绝对位置信息(实际上相对位置编码也需要插值),而绝对位置信息被证明非常重要。以 DeiT-Tiny 模型为例,作者通过简单的对比实验让用户直观的感受不同位置编码的效果:
2D PRE 是指 2D 相对位置编码,Top-1@224 表示测试时候采用 224 图片大小,这个尺度和训练保持一致,Top-1@384 表示测试时候采用 384 图片大小,由于图片大小不一致,故需要对位置编码进行插值。从上表可以得出:
- 位置编码还是很重要,不使用位置编码性能很差。
- 2D 相对位置编码性能比其他两个差,可学习位置编码和 sin-cos 策略效果非常接近,相对来说可学习绝对位置编码效果更好一些(和其他论文结论不一致)。
- 在需要对位置编码进行插值时候,性能都有下降。
基于上述描述,作者认为在视觉任务中一个好的位置编码应满足如下条件:
- 模型应该具有 permutation-variant 和 translation-equivariance 特性,即对位置敏感但同时具有平移不变性。
- 能够自然地处理变长的图片序列。
- 能够一定程度上编码绝对位置信息。
基于这三个原则,CPVT 引入了一个带有 zero-padding 的卷积 ( kernel size k ≥ 3) 来隐式地编码位置信息,并提出了 Positional Encoding Generator (PEG) 模块,如下所示:
将输入序列 reshape 成图像空间维度,然后通过一个 kernel size 为 k ≥ 3, (k−1)/2 zero paddings 的 2D 卷积操作,最后再 reshape 成 token 序列。这个 PEG 模块因为引入了卷积,在计算位置编码时候会考虑邻近的 token,当图片尺度改变时候,这个特性可以避免性能下降问题。算法的整体结构图如下所示:
基于 CPVT 的做法,PVTv2 将 zero-padding 卷积思想引入到 FFN 模块中。
通过在常规 FFN 模块中引入 zero-padding 的逐深度卷积来引入隐式的位置编码信息(称为 Convolutional Feed-Forward)。
同样的,在 CSWin Transformer 中作者也引入了 3x3 DW 卷积来增强位置信息,结构图如下所示:
APE 是 ViT 中的绝对位置编码,CPE 是 CPVT 中的条件位置编码,其做法是和输入序列 X 相加,而 RPE 是 Swin Transformer 中所采用的相对位置编码,其是加到 QK 矩阵计算后输出中,而本文所提的 Locally-Enhanced Positional Encoding (LePE),是在自注意力计算完成后额外加上 DW 卷积值,计算量比 RPE 小。
LePE 做法对于下游密集预测任务中图片尺寸变化情况比较友好,性能下降比较少。
除了上述分析的诸多加法隐式位置编码改进, ResT 提出了另一个非常相似的,但是是乘法的改进策略,结构图如下所示:
对 Patch Embedding 后的序列应用先恢复空间结构,然后应用一个 3×3 depth-wise padding 1的卷积来提供位置注意力信息,然后通过 sigmoid 操作变成注意力权重和原始输入相乘。代码如下所示:
class PA(nn.Module): def __init__(self, dim): super().__init__() self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) self.sigmoid = nn.Sigmoid() def forward(self, x): # x 是已经恢复了空间结构的 patch embedding return x * self.sigmoid(self.pa_conv(x))
作者还指出,这个Pixel Attention (PA) 模块可以替换为任意空间注意力模块,性能优异,明显比位置编码更加灵活好用。
3.1.2.3 自注意力模块
Transformer 的最核心模块是自注意力模块,也就是我们常说的多头注意力模块,如下图所示:
注意力机制的最大优势是没有任何先验偏置,只要输入足够的数据就可以利用全局注意力学到泛化性能不错的特征。当数据量足够大的时候,注意力机制是 Transformer 模型的最大优势,但是一旦数据量不够就会变成逆势,后续很多算法改进方向都是希望能够引入部分先验偏置辅助模块,在减少对数据量的依赖情况下加快收敛,并进一步提升性能。同时注意力机制还有一个比较大的缺点:因为其全局注意力计算,当输入高分辨率图时候计算量非常巨大,这也是目前一大改进方向。
简单总结,可以将目前自注意力模块分成 2 个大方向:
- 仅仅包括全局注意力,例如 ViT、PVT 等。
- 引入额外的局部注意力,例如 Swin Transformer。
如果整个 Transformer 模型不含局部注意力模块,那么其主要改进方向就是如何减少空间全局注意力的计算量,而引入额外的局部注意力自然可以很好地解决空间全局注意力计算量过大的问题,但是如果仅仅包括局部注意力,则会导致性能下降严重,因为局部注意力没有考虑窗口间的信息交互,因此引入额外的局部注意力的意思是在引入局部注意力基础上,还需要存在窗口间交互模块,这个模块可以是全局注意力模块,也可以是任何可以实现这个功能的模块。其结构图如下所示:
3.1.2.3.1 仅包括全局注意力
标准的多头注意力就是典型的空间全局注意力模块,当输入图片比较大的时候,会导致序列个数非常多,此时注意力计算就会消耗大量计算量和显存。以常规的 COCO 目标检测下游任务为例,输入图片大小一般是 800x1333,此时 Transformer 中的自注意力模块计算量和内存占用会难以承受。其改进方向可以归纳为两类:减少全局注意力计算量以及采用广义线性注意力计算方式。
(1) 减少全局注意力计算量
全局注意力计算量主要是在 QK 矩阵和 Softmax 后和 V 相乘部分,想减少这部分计算量,那自然可以采用如下策略:
- 降低 KV 维度,QK 计算量和 Softmax 后和 V 相乘部分计算量自然会减少。
- 减低 QKV 维度,主要如果 Q 长度下降了,那么代表序列输出长度改变了,在减少计算量的同时也实现了下采样功能。
(a) 降低 KV 维度
降低 KV 维度做法的典型代码是 PVT,其设计了空间 Reduction 注意力层 (SRA) ,如下所示:
其做法比较简单,核心就是通过 Spatial Reduction 模块缩减 KV 的输入序列长度,KV 是空间图片转化为 Token 后的序列,可以考虑先还原出空间结构,然后通过卷积缩减维度,再次转化为序列结构,最后再算注意力。假设 QKV shape 是完全相同,其详细计算过程如下:
- 在暂时不考虑 batch 的情况下,KV 的 shape 是 (H'W', C),既然叫做空间维度缩减,那么肯定是作用在空间维度上,故首先利用 reshape 函数恢复空间维度变成 (H', W', C)。
- 然后在这个 shape 下应用 kernel_size 和 stride 为指定缩放率例如 8 的二维卷积,实现空间维度缩减,变成 (H/R, W/R, C), R 是缩放倍数。
- 然后再次反向 reshape 变成 (HW/(R平方), C),此时第一维(序列长度)就缩减了 R 平方倍数。
- 然后采用标准的多头注意力层进行注意力加权计算,输出维度依然是 (H'W', C)。
而在 Twins 中提出了所谓的 GSA,其实就是 PVT 中的空间缩减模块。
同时基于最新进展,在 PVTV2 算法中甚至可以将 PVTv1 的 Spatial Attention 直接换成无任何学习参数的 Average Pooling 模块,也就是所谓的 Linear SRA,如下所示:
同样参考 PVT 设计,在 P2T 也提出一种改进版本的金字塔 Pool 结构,如下所示:
(b) 即为改进的 Spatial Attention 结构,对 KV 值应用不同大小的 kernel 进行池化操作,最后 cat 拼接到一起,输入到 MHSA 中进行计算,通过控制 pool 的 kernel 就可以改变 KV 的输出序列长度,从而减少计算量,同时金字塔池化结构可以进一步提升性能(不过由于其 FFN 中引入了 DW 卷积,也有一定性能提升)。
从降低 KV 空间维度角度出发,ResT 算法中也提出了一个内存高效的注意力模块 E-MSA,相比 PVT 做法更近一步,不仅仅缩减 KV 空间维度,还同时加强各个 head 之间的信息交互,如下所示:
其出发点有两个:
- 当序列比较长或者维度比较高的时候,全局注意力计算量过大。
- 当多头注意力计算时,各个头是按照 D 维度切分,然后独立计算最后拼接输出,各个头之间没有交互,当 X 维度较少时,性能可能不太行。
基于上述两点,作者引入 DWConv 缩放 KV 的空间维度来减少全局注意力计算量,然后在 QK 点乘后引入 1x1 Conv 模块进行多头信息交互。其详细做法如下:
- 假设输入序列 X Shape 是 nxd,n 表示序列长度,d 表示每个序列的嵌入向量维度。
- 假设想将特征图下采样 sxs 倍,可以将 X 输入到 kernel 为 (s+1,s+1),stride 为 (s, s), padding 为 (s//2, s//2) 的 DW 卷积和 LN 层中,假设输出变成 (h'w', d)。
- 将其经过线性映射,然后在 d 维度切分成 k 个部分,分别用于 k 个头中。
- QK 计算点积和 Scale 后,Shape 变成 (k, n, n'),然后对该输出采用 1x1 卷积在头的 k 维度进行多个 head 之间的信息聚合。
- 后续是标准的注意力计算方式。
其核心代码如下所示:
class Attention(nn.Module): def __init__(self, dim=32, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., sr_ratio=2): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 # (sr_ratio+1)x (sr_ratio+1) 的 DW 卷积 self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim) self.sr_norm = nn.LayerNorm(dim) # 1x1 卷积 self.transform_conv = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, stride=1) self.transform_norm = nn.InstanceNorm2d(self.num_heads) def forward(self, x, H, W): B, N, C = x.shape q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 1 空间下采样 x_ = x.permute(0, 2, 1).reshape(B, C, H, W) x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) x_ = self.sr_norm(x_) kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] # 2 输出维度为 (B,num_head, N, N') attn = (q @ k.transpose(-2, -1)) * self.scale # 3 在 num_head 维度进行信息聚合,加强 head 之间的联系 attn = self.transform_conv(attn) attn = attn.softmax(dim=-1) attn = self.transform_norm(attn) # 4 子注意力模块标准操作 attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x
(b) 降低 QKV 维度
Multiscale Vision Transformers (MViT) 也考虑引入 Pool 算子来减少全局注意力计算量。MViT 主要为了视频识别而设计,因为视频任务本身就会消耗太多的显存和内存,如果不进行针对性设计,则难以实际应用。正如其名字所言,其主要是想参考 CNN 中的多尺度特性(浅层空间分辨率大通道数少,深层空间分辨率小通道数多)设计出适合 Transformer 的多尺度 ViT。自注意力模块如下所示:
相比 Transformer 自注意力模块,其主要改变是多了 Pool 模块,该模块的主要作用是通过控制其 Stride 参数来缩放输入的序列个数,而序列个数对应的是图片空间尺度 THW。以图像分类任务为例,
- 任意维度的序列 X 输入,首先和 3 个独立的线性层得到 QKV,维度都是 (HW, D)。
- 将QKV 恢复空间维度,变成 (H, W, D),然后经过独立的 3 个 Pool 模块,通过控制 stride 参数可以改变输出序列长度,变成 (H', W', D),设置 3 个 Pool 模块不同的 Stride 值可以实现不同大小的输出。
- 将输入都拉伸为序列格式,然后采用自注意力计算方式输出 (H'W‘, D)。
- 为了保证输出序列长度改变而无法直接应用残差连接,需要在 X 侧同时引入一个 Pool 模块将序列长度和维度变成一致。
由于 MViT 出色的性能,作者将该思想推广到更多的下游任务中(例如目标检测),提出了改进版本的 Imporved MViT,其重新设计的结构图如下所示:
Imporved MViT 在不同的下游任务提升显著。
(2) 广义线性注意力计算方式
基于 NLP 中 Transformer 进展,我们可以考虑将其引入到 ViT 中,典型的包括 Performer,其可以通过分解获得一个线性时间注意力机制,并重新排列矩阵乘法,以对常规注意力机制的结果进行近似,而不需要显示构建大小呈平方增长的注意力矩阵。在 T2T-ViT 算法中则直接使用了高效的 Performer。
在 NLP 领域类似的近似计算方式也有很多,由于本文主要关注 ViT 方面的改进,故这部分不在展开分析。
3.1.2.3.2 引入额外局部注意力
引入额外局部注意力的典型代表是 Swin Transformer,但是卷积模块工作方式也可以等价为局部注意力计算方式,所以从目前发展来看,主要可以分成 3 个大类:
- 局部窗口计算模式,例如 Swin Transformer 这种局部窗口内计算。
- 引入卷积局部归纳偏置增强,这种做法通常是引入或多或少的卷积来明确提供局部注意力功能。
- 稀疏注意力。
结构图如下所示:
需要特别注意的是:
- 引入局部窗口注意力后依然要提供跨窗口信息交互模块,不可能只存在局部注意力模块,因为这样就没有局部窗口间的信息交互,性能会出现不同程度的下降,也不符合 Transformer 设计思想( Patch 内和 Patch 间信息交互)。
- 局部窗口计算模式和引入卷积局部归纳偏置增强的划分依据是其核心出发点和作用来划分,而不是从是否包括 Conv 模块来区分。
(1) 局部窗口计算模式
局部注意力的典型算法是 Swin Transformer,其将自注意力计算过程限制在每个提前划分的窗口内部,称为窗口注意力 Window based Self-Attention (W-MSA),相比全局计算自注意力,明显可以减少计算量,但是这种做法没法让不同窗口进行交互,此时就退化成了 CNN,所以作者又提出移位窗口注意力模块 Shifted window based Self-Attention (SW-MSA),示意图如下所示,具体是将窗口进行右下移位,此时窗口数和窗口的空间切分方式就不一样了,然后将 W-MSA 和 SW-MSA 在不同 stage 之间交替使用,即可实现窗口内局部注意力计算和跨窗口的局部注意力计算,同时其要求 stage 个数必须是偶数。
大概流程为:
- 假设第 L 层的输入序列 Shape 是 (N, C),而 N 实际上是 (H, W) 拉伸而来。
- 将上述序列还原为图片维度即(H, W, C), 假设指定每个窗口大小是 7x7,则可以将上述图片划分为 HW/49 个块,然后对每个块单独进行自注意力计算(具体实现上可以矩阵并行),这样就将整个图像的全局自注意力计算限制在了窗口内部即 W-MSA 过程。
- 为了加强窗口间的信息交流,在 L+1 层需要将 W-MSA 换成 SW-MSA,具体是将 L 层的输出序列进行 shift 移位操作,如上图所示,从 4 个窗口就变成了 9 个窗口,此时移位后的窗口包含了原本相邻窗口的元素,有点像窗口重组了,如果在这 9 个窗口内再次计算 W-MSA 其输出就已经包括了 L 层窗口间的交互信息了。
上述只是原理概述,实际上为了保证上述操作非常高效,作者对代码进行了非常多的优化,相对来说是比较复杂的。值得注意的是 Swin Transformer 相比其他算法(例如 PVT )非常高效,因为整个算法中始终不存在全局注意力计算模块( SW-MSA 起到类似功能),当图片分辨率非常高的时候,也依然是线性复杂度的,这是其突出优点。凭借其局部窗口注意力机制,刷新了很多下游任务的 SOTA,影响非常深远。
在 Swin Transformer v2 中探讨了模型尺度和输入图片增大时候,整个架构的适应性和性能。在大模型实验中作者观察到某些 block 或者 head 中的 attention map 会被某些特征主导,产生这个的原因是原始 self-attention 中对于两两特征之间的相似度衡量是用的内积,可能会出现某些特征 pair 内积过大。为了改善这个问题,作者将内积相似度替换为了余弦相似度,因为余弦函数本身的取值范围本身就相当于是被归一化后的结果,可以改善因为些特征 pair 内积过大,主导了 attention 的情况,结构图如下所示:
Swin Transformer 算法在解决图片尺度增加带来的巨大计算量问题上有不错的解决方案,但是 SW-MSA 这个结构被后续诸多文章吐槽,主要包括:
- 为了能够高效计算,SW-MSA 实现过于复杂。
- SW-MSA 对 CPU 设备不友好,难以部署。
- 或许有更简单更优雅的跨窗口交互机制。
基于这三个问题,后续学者提出了大量的针对性改进,可以归纳为两个方向:
- 抛弃 SW-MSA,依然需要全局注意力计算模块,意思是不再需要 SW-MSA,跨窗口交互功能由全局注意力计算模块代替,当然这个全局注意力模块是带有减少计算量功能的。
- 抛弃 SW-MSA,跨窗口信息交互由特定模块提供,这个特定模块就是改进论文所提出的模块。
(a) 抛弃 SW-MSA,依然需要全局注意力计算模块
Imporved MViT 在改进的 Pool Attention 基础上,参考 Swin Transformer 在不同 stage 间混合局部注意力 W-MSA 和 SW-MSA 设计,提出 HSwin 结构,在 4 个 stage 中的最后三个 stage 的最后一个 block 用全局注意力 Pool Attention 模块(具体间 3.1.2.3.1 小节),其余 stage 的 block 使用 W-MSA ,实验表明这种设计比 Swin Transformer 性能更强,也更简单。
同样 Twins 也借鉴了 W-MSA 做法,只不过由于其位置编码是隐式生成,故不再需要相对位置编码,而且 SW-MSA 这种 Shift 算子不好部署,所以作者的做法是在每个 Encoder 中分别嵌入 Locally-grouped self-attention (LSA) 模块即不带相对位置编码的 W-MSA 以及 GSA 模块,GSA 就是 PVT 中使用的带空间缩减的全局自注意力模块,通过 LSA 计算局部窗口内的注意力,然后通过全局自注意力模块 GSA 计算窗口间的注意力,结构图如下所示:
(b) 抛弃 SW-MSA,跨窗口信息交互由特定模块提供
参考 CNN 网络设计思想,可以设计跨窗口信息交互模块,典型的论文包括 MSG-T 、Glance-and-Gaze Transformer 和 Shuffle Transformer。
MSG-Transformer 基于 W-MSA,通过引入一个 MSG Token 来加强局部窗口之间的信息交互即在每个窗口内额外引入一个 MSG Token,该 Token 主要作用就是进行窗口间信息传递,所设计的模块优点包括对 CPU 设备友好,推理速度比 SWin 快,性能也更好一些。结构图如下所示:
假设将图片或者输入序列划分为 4x4 个窗口,在每个窗口内部再划分为 2x2 个 shuffle 区域。- 在每个窗口中会额外拼接一个可学习的 MSG Token (三角形),故一共需要拼接 2x2 个可学习的 MSG Token。
- 将拼接后的所有 token 经过 layer norm、Swin Transformer 中的 W-MSA 和残差连接后,同一个窗口内的 token 会进行注意力计算,从而进行窗口内信息融合。
- 单独对 2x2 个 MSG Token 进行 shuffle 操作,交互 2x2 个 token 信息。
- 然后对输出再次进行 layer norm、Channel MLP 和残差连接后输出即可。
在第 3 步的 W-MSA 计算中,可以认为同一个窗口内会进行信息流通,从而 2x2 个 MSG Token 都已经融合了对应窗口内的信息,然后经过第 4 步骤 MSG Token 交换后就实现了局部窗口间信息的交互。MSG Token 信息交互模块完成了 Swin Transformer 中 SW-MSA ,相比 SW-MSA 算子,不管是计算复杂度还是实现难度都非常小。Shuffle 计算过程和 ShuffleNet 做法完全一样,如下所示:
将 Swin Transformer 中的 block 全部换成 MSG-Transformer block ,通过实验验证了本结构的优异性。
Shuffle Transformer 也是从效率角度对 Swin Transformer 的 SW-MSA 进行改进,正如其名字,其是通过 Shuffle 操作来加强窗口间信息交流,而不再需要 SW-MSA,由于其做法和 ShuffleNet 一致就不再详细说明,核心思想如下所示 (c):
将 Swin Transformer 中的 SW-MSA 全部换成 Shuffle W-MSA,在此基础上还引入了额外的 NWC 模块,其是一个 DW Conv,其 kernel size 和 window size 一样,用于增强邻近窗口信息交互,Shuffle Transformer 的 block 结构如下所示:
在 ImageNet 数据集上,同等条件上 Shuffle Transformer相比 Swin 有明显提升,在 COCO 数据集上,基于 Mask R-CNN,Shuffle Transformer 和 Swin 性能不相上下。
因为 Swin Transformer 不存在 NWC 模块,作者也进行了相应的对比实验:
这也进一步验证了引入适当的 CNN 局部算子可以在几乎不增加计算量的前提下显著提升性能。
MSG-Transformer 和 Shuffle Transformer 都是通过直接替换 SW-MSA 模块来实现的,Glance-and-Gaze Transformer (GG-Transformer) 则认为没有必要分成两个独立的模块,只需要通过同一个模块的两个分支融合就可以同时实现 W-MSA 和 SW-MSA 功能。结构图如下所示:
其提出一种新的局部窗口注意力计算机制,相比常规的近邻划分窗口,其采用了自适应空洞率窗口划分方式,上图是假设空洞率是 2 即每隔 1 个位置,这样就可以将图片划分为 4 个窗口,由于其采样划分方式会横跨多个像素位置,相比 Swin Transofrmer 划分方式具有更大的感受野,不断 stage 堆叠就可以实现全局感受野。在 Glance 分支中采用 MSA 局部窗口计算方法计算局部注意力,同时为了增强窗口之间的交互,其将 V 值还原为原先划分模式,然后应用 depth-wise conv 来提取局部信息,再通过自适应空洞划分操作的逆操作还原,再加上 Attention 后的特征。
Glance 分支用于在划分窗口内单独计算窗口内的局部注意力,由于其自适应空洞率窗口划分方式,使其能够具备全局注意力提取能力,而 Gaze分支用于在划分的窗口间进行信息融合,具备窗口间局部特征提取能力。将 Swin Transformer 中的 block 全部换成 GG-Transformer block ,通过实验验证了其性能优于 Swin Transformer 。
在改进 Swin Transformer 的窗口注意力计算方式这方面,CSWin Transformer 相比其余改进更加独特,其提出了十字架形状的局部窗口划分方式,如下图所示:
假设一共将图片划分成了 9 个窗口,本文所提注意力的计算只会涉及到上下左右中 5 个窗口,同时为了进一步减少计算量,又分成 horizontal stripes self-attention 和 vertical stripes self-attention,每个自注意力模块都只涉及到其中 3 个窗口,这种计算方式同时兼顾了局部窗口计算和跨窗口计算,一步到位。所谓 horizontal stripes self-attention 是指沿着 H 维度将 Tokens 分成水平条状 windows,假设一共包括 k 个头,则前 k/2 个头用于计算 horizontal stripes self-attention,后面 k/2 个头用于计算 vertical stripes self-attention。两组self-attention是并行的,计算完成后将 Tokens 的特征 concat 在一起,这样就构成了CSW self-attention,最终效果就是在十字形窗口内做 Attention,可以看到 CSW self-attention 的感受野要比常规的 Window Attention 的感受野更大。可以通过控制每个条纹的宽度来控制自注意力模块的复杂度,默认 4 个 stage 的条纹宽度分别设为 1, 2, 7, 7(图片空间维度比较大的时候采用较小的条纹宽度,减少计算量)。
(2) 引入卷积的局部归纳偏置能力
上述都是属于 Swin Transformer 改进,在引入卷积局部归纳偏置增强方面,典型算法为 ViTAE 和 ELSA,ViTAE 结构图如下所示:
其包括两个核心模块:reduction cell (RC) 和 normal cell (NC)。RC 用于对输入图像进行下采样并将其嵌入到具有丰富多尺度上下文的 token 中,而 NC 旨在对 token 序列中的局部性和全局依赖性进行联合建模,可以看到,这两种类型的结构共享一个简单的基本结构。
对于 RC 模块,分成两个分支,第一条分支首先将特征图输入到不同空洞率并行的卷积中,提取多尺度特征的同时也减少分辨率,输出特征图拼接+ GeLU 激活,然后输入到注意力模块中,第二条分支是纯粹的 Conv 局部特征提取,用于加强局部归纳偏置,两个分支内容相加,然后输入到 FFN 模块中。
对于 NC 模块,类似分成两个分支,第一条是注意力分支,第二条是 Conv 局部特征提取,用于加强局部归纳偏置,两个分支内容相加,然后输入到 FFN 模块中。基于上述模块,构建了两个典型网络,如下所示:
至于为何要如此设置以及各个模块的前后位置,作者进行了大量的实验研究:
ELSA: Enhanced Local Self-Attention for Vision Transformer基于一个现状:Swin Transformer 中所提的局部自注意力(LSA)的性能与卷积不相上下,甚至不如动态过滤器。如果是这样,那么 LSA 的重要性就值得怀疑了。最近也有很多学者发现了个奇怪的现象,例如Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight 中深入分析了注意力和 Conv 的关系,特别是 DW Conv,但是大部分都没有深入探讨 LSA 性能如此平庸的具体原因,本文从这个方面入手,并提出了针对性改进。
2022-01-26 18:00
文章来源:【OpenMMLab】