再次介绍一下我的专栏,很适合大家初入深度学习或者是Pytorch和Keras,希望这能够帮助初学深度学习的同学一个入门Pytorch或者Keras的项目和在这之中更加了解Pytorch&Keras和各个图像分类的模型
他有比较清晰的可视化结构和架构,除此之外,我是用jupyter写的,所以说在文章整体架构可以说是非常清晰,可以帮助你快速学习到各个模块的知识,而不是通过 python脚本Q一行一行的看,这样的方式是符合初学者的。
除此之外,如果你需要变成脚本形式,也是很简单的。
这里贴一下汇总篇: 汇总篇
4.定义网络(Swin Transformer)
自从Transformer在NLPQ任务上取得突破性的进展之后,业内一直尝试着把Transformer用于在CV领域。之前的若干尝试,例如iGPT,ViT都是将Transformer用在了图像分类领域,ViT我们之前也有介绍过在图像分类上的方法,但目前这些方法都有两个非常严峻的问题
1.受限于图像的矩阵性质,一个能表达信息的图片往往至少需要几百个像素点,而建模这种几百个长序列的数据恰恰是Transformer的天生缺陷
2.目前的基于Transformer框架更多的是用来进行图像分类,理论上来进解决检测问题应该也比较容易,但是对实例分割这种密集预测的场景Transformer并不擅长解决。
而这篇微软亚洲研究院提出的的Swin Transformer解决了这两个问题,并且在分类,检测,分割任务上都取得了SOTA的效果,同时获得了ICCV2021的best paper。Swin Transformer的最大贡献是提出了一个可以广泛应用到所有计算机视觉领域的backbone,并且大多数在CNN网络中常见的超参数在Swin Transformer中也是可以人工调整的,例如可以调整的网络块数,每一块的层数,输入图像的大小等等。该网络架构的设计非常巧妙,是一个非常精彩的将Transformer应用到图像领域的结构,值得每个AI领域的人前去学习。
实际上的,Swin Transformer 是在 Vision Transformer 的基础上使用滑动窗口 (shifted windows,SW)进行改造而来。它将 Vision Transformer 中固定大小的采样快按照层次分成不同大小的块(Windows),每一个块之间的信息并不共通、独立运算从而大大提高了计算效率。从 SwinTransformer 的架构图中可以看出其与 Vision Transformer 的结构很相似,不同的点在于其采用的Transformer Block 是由两个连续的 Swin Transformer Block 构成的,这两个 Block 块与 VisionTransformer中的 Block 块大致相同,只是将 Multi-head Self-Attention (MSA) 替换成了含有不同大小Windows 的 W-MSA与SW-MAS (具有滑动窗.SW),通过 Windows 和 Shifted Windows 的Multi-head Self-Attention 提高运算效率并最终提高分类的准确率。
Swin Transformer整体架构
从 Swin Transformer 网络的整体框架图我们可以看到,首先将输入图像1输入到 Patch Partition 进行一个分块操作,这一部分其实是和VIT是一样的,然后送入 Linear Embedding 模块中进行通道数channel 的调整。最后通过 stage 1,2,3 和 4 的特征提取和下采样得到最终的预测结果,值得注意的是每经过一个 stage,size 就会缩小为原来的 1/2,channel 就会扩大为原来的 2倍与resnet 网络类似。每个 stage 中的 Swin Transformer Block 都由两个相连的分别以 W-MSA和 SW-MSA为基础的 Transformer Block 构成,通过 Window 和 Shifted Window 机制提高计算性能。最右边两个图为Swim Transformer的每个块结构,类似于ViT的块结构,其核心修改的地方就是将原本的MSA变为WMSA。
Patch Merging
Patch Merging 模块将 尺寸为 H X W 的 Patch 块首先进行拼接并在 channel 维度上进行concatenate 构成了 H/2 x W/2 4C 的特征图,然后再进行 Layer Normalization 操作进行正则化,然后通过一个 Linear 层形成了一个 H/2 x W/2 2C ,完成了特征图的下采样过程。其中size 缩小为原来的 1/2,channel 扩大为原来的 2倍。
这里也可以举个例子,假设我们的输入是4x4大小单通道的特征图,首先我们会隔一个取一个小Patch组合在一起,最后4x4的特征图会行成4个2x2的特征图。接下来将4个Patch进行拼接,现在得到的特征图尺寸为2x2x4。然后会经过一个LN层,这里当然会改变特征图的值,我改变了一些颜色象征性的表示了一下,LN层后特征图尺寸不会改变,仍为2x2x4。最后会经过一个全连接层,将特征图尺寸由2x2x4变为2x2x2。
W-MSA
ViT 网络中的 MSA通过 self-Attention 使得每一个像素点都可以和其他的像素点进行内积从而得到所有像素点的信息,从而获得丰富的全局信息。但是每个像素点都需要和其他像素点进行信息交换计算量巨大,网络的执行效率低下。因此 Swin-T 将 MSA 分个多个固定的 Windows 构成了 W-MSA,每个 Windows 之间的像素点只能与该 Windows 中的其他像素点进行内积从而获得信息,这样便大幅的减小了计算量,提高了网络的运算效率。
MSA和 W-MAS 的计算量如下所示
其中 h、w 和 C 分别代表特征图的高度、!宽度和深度,M 代表每个 Windows 的大小。
假定h =w =64,M =4,C = 96
采用MSA模块的计算复杂度为 4 x 64 x 64 962 2 (64 64)2 96 = 3372220416采用W-MSA模块的计算复杂度为 4 x 64 x 6 962 -2 4 64 64 96 = 163577856可以计算出 W-MSA 节省了3208642560 FLOPs。
SW-MSA
虽然 W-MSA 通过划分 Windows 的方法减少了计算量,但是由于各个 Windows 之间无法进行信息的交互,因此可以看作其“感受野”缩小,无法得到较全局准确的信息从而影响网络的准确度。为了实现不同窗口之间的信息交互,我们可以将窗口滑动,偏移窗口使其包合不同的像素点,然后再进行 W.MSA计算,将两次 W-MSA计算的结果进行连接便可结合两个不同的 Windows 中的像素点所包含的信息从而实现 Windows 之间的信息共通。
偏移窗口的 W-MSA构成了 SW-MSA 模块,其 Windows 在 W-MSA的基础上向右下角偏移了两个Patch,形成了9个大小不一的块,然后使用 cyclic shift 将这9 个块平移拼接成与 W-MSA 对应的4个大小相同的块,再通过 masked MSA 对这 4 个拼接块进行对应的模板计算完成信息的提取,最后通过 reverse cyclic shift 将信息数据 patch 平移回原先的位置。通过 SW-MSA机制完成了偏移窗口的象素点的 MSA计算并实现了不同窗口间像素点的信息交流,从而间接扩大了网络的“感受野”,提高了信息的利用效率
我们仔细说明一下这一部分,上面可能比较抽象,这一块我认为也是Swin Transformer的核心。可以发现通过将窗口进行偏移后,就到达了窗口与窗口之间的相互通信。虽然已经能达到窗口与窗口之间的通信,但是原来的特征图只有4个窗口,经过移动窗口后,得到了9个窗口,窗口的数量有所增加并且9个窗口的大小也不是完全相同,这就导致计算难度增加。因此,作者又提出而了Efficient batchcomputation for shifted configuration,一种更加高效的计算方法。如下图所示:
先将012区域移动到最下方,再将360区域移动到最右方,此时移动完后,4是一个单独的窗口;将5和3合并成一个窗口;7和1合并成一个窗口; 8、6、2和0合并成一个窗口。这样又和原来一样是4个4x4的窗口了,所以能够保证计算量是一样的。这里肯定有人会想,把不同的区域合并在一起(比如5和3)进行MSA,这信息不就乱窜了吗? 是的,为了防止这个问题,在实际计算中使用的是maskedMSA即带蒙板mask的MSA,这样就能够通过设置Mask来隔绝不同区域的信息了
Relative position bias
Swin-T 网络还在 Attention 计算中引入了相对位置偏置机制去提高网络的整体准确率表现,通过引入相对位置偏置机制,其准确度能够提高 1.2%~2.3% 不等。以 2x2 的特征图为例,首先我们需要对特征图的各个块进行绝对位置的编号,得到每个块的绝对位置索引。然后对每个块计算其与其他块之间的相对位置,计算方法为该块的绝对位置索引减去其他块的绝对位置索引,可以得到每个块的相对位置索引矩阵。将每个块的相对位置索引矩阵展平连接构成了整个特征图的相对位置索引矩阵,具体的计算流程如下图所示。
Swin-T并不是使用二维元组形式的相对位置索引矩阵,而是通过将二维元组形式的相对位置索引映射为一维的相对位置偏置(Relative position bias) 来构成相应的矩阵,具体的映射方法如下: 1.将对应的相对位置行索引和列索引分别加上 M-1,2.将行索引乘以 2M-1,3.将行索引和列索引相加,再使用对应的相对位置偏置表(Relative position bias table) 进行映射即可得到最终的相对位置偏置B。具体的计算流程如下所示
如果这一部分看的比较迷糊,也可以简单看看直接从相对位置进行映射,我们就可以得到相对位置偏置
加入了相对位置偏置机制的 Attention 计算公式如下所示
其中 B 即为上述计算得到的相对位置偏置
Swin Transformer 网络结构
下表是原论文中给出的关于不同Swin Transformer的配置,T(Tiny),S(Small),B(Base),L(Large),其中:
win.sz.7x7表示使用的窗 (Windows) 的大小
dim表示feature map的channel深度 (或者说token的向量长度)
head表示多头注意力模块中head的个数
首先我们还是得判断是否可以利用GPU,因为GPU的速度可能会比我们用CPU的速度快20-50倍左右,特别是对卷积神经网络来说,更是提升特别明显。
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Patch Embedding
在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。
具体做法是对原始图片裁成一个个 patch_size * patch_size 的窗口大小,然后进行嵌入。
这里可以通过二维卷积层,将stride,kernelsize设置为patch size大小。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度
class PatchEmbed(nn.Module): def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None): super(PatchEmbed, self).__init__() self.patch_size = patch_size self.in_c = in_c self.embed_dim = embed_dim self.proj = nn.Conv2d( in_c, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): # 如果图片的H,W不是patch_size的整数倍,需要padding _, _, h, w = x.shape if (h % self.patch_size != 0) or (w % self.patch_size != 0): x = F.pad(x, (0, self.patch_size - w % self.patch_size, 0, self.patch_size - h % self.patch_size, 0, 0)) x = self.proj(x) _, _, h, w = x.shape # (b,c,h,w) -> (b,c,hw) -> (b,hw,c) x = x.flatten(2).transpose(1, 2) x = self.norm(x) return x, h, w
Patch Merging
该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。
在CNN中,则是在每个Stage开始前用 stride=2 的卷积/池化层来降低分辨率
每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素
然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍 (因为H.W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍
class PatchMerging(nn.Module): def __init__(self, dim, norm_layer=nn.LayerNorm): super(PatchMerging, self).__init__() self.dim = dim self.reduction = nn.Linear(4*dim, 2*dim, bias=False) self.norm = norm_layer(4*dim) def forward(self, x, h, w): # (b,hw,c) b, l, c = x.shape # (b,hw,c) -> (b,h,w,c) x = x.view(b, h, w, c) # 如果h,w不是2的整数倍,需要padding if (h % 2 == 1) or (w % 2 == 1): x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2)) # (b,h/2,w/2,c) x0 = x[:, 0::2, 0::2, :] x1 = x[:, 1::2, 0::2, :] x2 = x[:, 0::2, 1::2, :] x3 = x[:, 1::2, 1::2, :] # (b,h/2,w/2,c)*4 -> (b,h/2,w/2,4c) x = torch.cat([x0, x1, x2, x3], -1) # (b,hw/4,4c) x = x.view(b, -1, 4*c) x = self.norm(x) # (b,hw/4,4c) -> (b,hw/4,2c) x = self.reduction(x) return x
下面是一个示意图 (输入张量N=1,H=W=8,C=1,不包含最后的全连接层调整)
Window Partition/Reverse
window partition 函数是用于对张量划分窗口,指定窗口大小。将原本的张量从 N H W c,划分成num_windows*B,window_size, window_size, C,其中 num windows = H*W/window size*window size),即窗口的个数。而window reverse 函数则是对应的逆过程。这两个函数会在后面的 window Attention 用到。
def window_partition(x, window_size): """ 将feature map按照window_size分割成windows """ b, h, w, c = x.shape # (b,h,w,c) -> (b,h//m,m,w//m,m,c) x = x.view(b, h//window_size, window_size, w//window_size, window_size, c) # (b,h//m,m,w//m,m,c) -> (b,h//m,w//m,m,m,c) # -> (b,h//m*w//m,m,m,c) -> (b*n_windows,m,m,c) windows = (x .permute(0, 1, 3, 2, 4, 5) .contiguous() .view(-1, window_size, window_size, c)) return windows def window_reverse(x,window_size,h,w): """ 将分割后的windows还原成feature map """ b = int(x.shape[0] / (h*w/window_size/window_size)) # (b,h//m,w//m,m,m,c) x = x.view(b,h//window_size,w//window_size,window_size,window_size,-1) # (b,h//m,w//m,m,m,c) -> (b,h//m,m,w//m,m,c) # -> (b,h,w,c) x = x.permute(0,1,3,2,4,5).contiguous().view(b,h,w,-1) return x class MLP(nn.Module): def __init__(self, in_features, hid_features=None, out_features=None, dropout=0.): super(MLP, self).__init__() out_features = out_features or in_features hid_features = hid_features or in_features self.fc1 = nn.Linear(in_features, hid_features) self.act = nn.GELU() self.drop1 = nn.Dropout(dropout) self.fc2 = nn.Linear(hid_features, out_features) self.drop2 = nn.Dropout(dropout) def forward(self, x): x = self.drop1(self.act(self.fc1(x))) x = self.drop2(self.fc2(x)) return x
Window Attention
这是这篇文章的关键。传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。
我们先简单看下公式
主要区别是在原始计算Attention的公式中的Q,K时加入了相对位置编码。后续实验有证明相对位置编码的加入提升了模型性能。
class WindowAttention(nn.Module): def __init__(self, dim, window_size, n_heads, qkv_bias=True, attn_dropout=0., proj_dropout=0.): super(WindowAttention, self).__init__() self.dim = dim self.window_size = window_size self.n_heads = n_heads self.scale = (dim // n_heads) ** -.5 # ((2m-1)*(2m-1),n_heads) # 相对位置参数表长为(2m-1)*(2m-1) # 行索引和列索引各有2m-1种可能,故其排列组合有(2m-1)*(2m-1)种可能 self.relative_position_bias_table = nn.Parameter( torch.zeros((2*window_size - 1) * (2*window_size - 1), n_heads)) # 构建窗口的绝对位置索引 # 以window_size=2为例 # coord_h = coord_w = [0,1] # meshgrid([0,1], [0,1]) # -> [[0,0], [[0,1] # [1,1]], [0,1]] # -> [[0,0,1,1], # [0,1,0,1]] # (m,) coord_h = torch.arange(self.window_size) coord_w = torch.arange(self.window_size) # (m,)*2 -> (m,m)*2 -> (2,m,m) coords = torch.stack(torch.meshgrid([coord_h, coord_w])) # (2,m*m) coord_flatten = torch.flatten(coords, 1) # 构建窗口的相对位置索引 # (2,m*m,1) - (2,1,m*m) -> (2,m*m,m*m) # 以coord_flatten为 # [[0,0,1,1] # [0,1,0,1]]为例 # 对于第一个元素[0,0,1,1] # [[0],[0],[1],[1]] - [[0,0,1,1]] # -> [[0,0,0,0] - [[0,0,1,1] = [[0,0,-1,-1] # [0,0,0,0] [0,0,1,1] [0,0,-1,-1] # [1,1,1,1] [0,0,1,1] [1,1, 0, 0] # [1,1,1,1]] [0,0,1,1]] [1,1, 0, 0]] # 相当于每个元素的h减去每个元素的h # 例如,第一行[0,0,0,0] - [0,0,1,1] -> [0,0,-1,-1] # 即为元素(0,0)相对(0,0)(0,1)(1,0)(1,1)为列(h)方向的差 # 第二个元素即为每个元素的w减去每个元素的w # 于是得到窗口内每个元素相对每个元素高和宽的差 # 例如relative_coords[0,1,2] # 即为窗口的第1个像素(0,1)和第2个像素(1,0)在列(h)方向的差 relative_coords = coord_flatten[:, :, None] - coord_flatten[:, None, :] # (m*m,m*m,2) relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 论文中提到的,将二维相对位置索引转为一维的过程 # 1. 行列都加上m-1 # 2. 行乘以2m-1 # 3. 行列相加 relative_coords[:, :, 0] += self.window_size - 1 relative_coords[:, :, 1] += self.window_size - 1 relative_coords[:, :, 0] *= 2 * self.window_size - 1 # (m*m,m*m,2) -> (m*m,m*m) relative_pos_idx = relative_coords.sum(-1) self.register_buffer('relative_pos_idx', relative_pos_idx) self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias) self.attn_dropout = nn.Dropout(attn_dropout) self.proj = nn.Linear(dim, dim) self.proj_dropout = nn.Dropout(proj_dropout) nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask): b, n, c = x.shape # (b*n_windows,m*m,total_embed_dim) # -> (b*n_windows,m*m,3*total_embed_dim) # -> (b*n_windows,m*m,3,n_heads,embed_dim_per_head) # -> (3,b*n_windows,n_heads,m*m,embed_dim_per_head) qkv = (self.qkv(x) .reshape(b, n, 3, self.n_heads, c//self.n_heads) .permute(2, 0, 3, 1, 4)) # (b*n_windows,n_heads,m*m,embed_dim_per_head) q, k, v = qkv.unbind(0) q = q * self.scale # (b*n_windows,n_heads,m*m,m*m) attn = (q @ k.transpose(-2, -1)) # (m*m*m*m,n_heads) # -> (m*m,m*m,n_heads) # -> (n_heads,m*m,m*m) # -> (b*n_windows,n_heads,m*m,m*m) + (1,n_heads,m*m,m*m) # -> (b*n_windows,n_heads,m*m,m*m) relative_pos_bias = (self.relative_position_bias_table[self.relative_pos_idx.view(-1)] .view(self.window_size*self.window_size, self.window_size*self.window_size, -1)) relative_pos_bias = relative_pos_bias.permute(2, 0, 1).contiguous() attn = attn + relative_pos_bias.unsqueeze(0) if mask is not None: # mask: (n_windows,m*m,m*m) nw = mask.shape[0] # (b*n_windows,n_heads,m*m,m*m) # -> (b,n_windows,n_heads,m*m,m*m) # + (1,n_windows,1,m*m,m*m) # -> (b,n_windows,n_heads,m*m,m*m) attn = (attn.view(b//nw, nw, self.n_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)) # (b,n_windows,n_heads,m*m,m*m) # -> (b*n_windows,n_heads,m*m,m*m) attn = attn.view(-1, self.n_heads, n, n) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_dropout(attn) # (b*n_windows,n_heads,m*m,embed_dim_per_head) # -> (b*n_windows,m*m,n_heads,embed_dim_per_head) # -> (b*n_windows,m*m,total_embed_dim) x = (attn @ v).transpose(1, 2).reshape(b, n, c) x = self.proj(x) x = self.proj_dropout(x) return x
首先输入张量形状为 numwindows*B,window_size * window size,C然后经过 self.qkv 这个全连接层后,进行reshape,调整轴的顺序,得到形状为 3,numwindows*B,num heads,window_size*window_size, c//num heads ,并分配给q,k,v。根据公式,我们对g 乘以一个 scale 缩放系数,然后与 (为了满足矩阵乘要求,需要将最后两维度调换) 进行相乘。得到形状为 (numwindows*B,num heads,window size*window size,vindow size*window size)的 attn 张量
之前我们针对位置编码设置了个形状为(2*window size-1*2*window size-1,numHeads)的可学习变量。我们用计算得到的相对编码位置索引 self.relative_position_index 选取,得到形状为(window_size*window_size,window_size*window size,numHeads)的编码,加到attn张量上暂不考虑mask的情况,剩下就是跟transformer一样的softmax,dropout,与v矩阵乘,再经过-吴全连接层和dropout
Pytorch CIFAR10图像分类 Swin Transformer篇(二):https://developer.aliyun.com/article/1410618