从一个参数说起✦
在图像分类任务中,主干网络是视觉神经网络中进行图像特征提取的主体,常见的算法包括我们耳熟能详的 ResNet、Vision Transformer 等。
不知道大家是否注意到,用于图像分类的主干网络中,基于 CNN 结构的网络,通常不需要我们指定输入图像的尺寸,同时,同一个主干网络就能够处理各种尺寸的图像输入。但基于 Transformer 结构的主干网络,就往往需要我们在搭建网络时指定输入的图像尺寸参数 —— img_size,而且网络的前向推理输入也必须是符合这一尺寸的图像。
那么,为什么 Transformer 结构的网络中需要指定输入的图像尺寸呢?我们能否移除这个限制,让网络动态地支持各种尺寸的输入图像呢?这对于一些下游任务有重要的作用,也已经有了一些成熟的解决方案。在最新版的 MMClassification 中,我们将这一功能扩展到了各种基于 Transformer 结构的主干网络中,实现了分类任务与下游任务主干网络的统一。
接下来,就让我们了解一下,Transformer 结构网络支持动态输入尺寸的阻碍与解决方法。
“罪魁祸首”——位置编码✦
说起 Transformer 结构,大家最先想到的关键结构大概率是注意力模块,但这里,问题并不出在注意力模块中,因为注意力模块天然地支持动态尺寸输入。让我们看下这张经典的 ViT 结构图:
首先,我们会将输入图片按照一个固定的 patch size 切分成若干个 patch。之后每个图像 patch 经过一个线性映射得到对应的一个特征向量。这一个个特征向量如果按照其对应 patch 在图像上的位置排列,就是一张图像经过编码后的特征图,其长和宽分别等于原图在纵向和横向切分成了多少个 patch。之后,我们需要给这张特征图加上位置编码(position embedding),以体现每个 patch 在图像上的位置。
当输入图片尺寸发生变化时,由于每个 patch 的尺寸固定,图片切分出的 patch 数就会发生变化。表现在上述特征图中,就是特征图的尺寸发生了变化。这样一来,我们原本位置编码图的尺寸就和图像特征图的尺寸对不上了,无法进行后续的计算。
找到了问题所在,解决的方法也就顺理成章了。位置编码代表的是 patch 所在位置的附加信息,那么如果和图像特征图的尺寸不匹配,只需要使用双三次插值法(Bicubic)对位置编码图进行插值缩放,缩放到与图像特征图一致的尺寸,就同样可以表现每个 patch 在图片中的位置信息。
import torch import torch.nn.functional as F # 原始位置编码 pos_embed = torch.rand(1, 197, 64) # 原始图像尺寸下,长和宽方向的 patch 数 src_shape = (14, 14) # 输入图像尺寸下,长和宽方向的 patch 数 dst_shape = (16, 16) # 额外编码数,在 ViT 中,为 1,指 class embedding;在 DeiT 中为 2 num_extra_tokens = 1 _, L, C = pos_embed.shape src_h, src_w = src_shape # 位置编码第二个维度大小应当等于 patch 数 + 额外编码数 assert L == src_h * src_w + num_extra_tokens # 拆分额外编码和纯位置编码 extra_tokens = pos_embed[:, :num_extra_tokens] src_weight = pos_embed[:, num_extra_tokens:] # 将位置编码组织成 (1, C, H, W) 形式,其中 C 为通道数 src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) # 进行双三次插值 dst_weight = F.interpolate(src_weight, size=dst_shape, mode='bicubic') # 重组位置编码为(1,H*W, C)形式,再拼接上额外编码,即获得新的位置编码 dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) pos_embed = torch.cat((extra_tokens, dst_weight), dim=1)
在官方实现中,这一方法已经有了一些应用。不过源码只把这一方法用于微调模型时,加载和微调模型输入尺寸不同的预训练模型权重。而在 MMClassificaiton 中,我们将这一方法应用于每次模型的前向推理,使每次推理都可以应对不同尺寸的图像输入。
需要提醒的是,就像缩放照片会损失信息,这种对位置编码的插值也不是无损的,建议输入图像的尺度变化不要过大,同时需要在动态尺度输入下进行新的微调训练。
下面一个例子,展示了在 MMClassification 中使用 ViT 模型处理不同尺寸输入的流程:
import torch from mmcls.models import build_backbone cfg = dict(type='VisionTransformer', arch='base') vit_model = build_backbone(cfg) inputs = torch.rand(1, 3, 224, 224) patch_embed, cls_token = vit_model(inputs)[-1] # 获取模型最后一层输出 assert patch_embed.shape == (1, 768, 14, 14) inputs = torch.rand(1, 3, 256, 384) patch_embed, cls_token = vit_model(inputs)[-1] assert patch_embed.shape == (1, 768, 16, 24) # 输入尺寸不同,输出特征图的尺寸也不同
特殊的 Swin-Transformer✦
Position embedding 的问题遍布于经典 ViT 结构的主干网络中,但并不存在于 Swin-Transformer 中。对 Swin-Transformer 有了解的读者应该知道,在 Swin-Transformer 中,没有使用绝对位置编码,也即上文所说的那种与输入图像 patch 一一对应的位置编码;而是配合窗口注意力机制,使用了一种局限于窗口内部的相对位置编码机制。当我们改变输入图像的大小,可能会改变窗口的数量,但并不会影响窗口内部的相对位置编码。
那么 Swin-Transformer 是否天然地具备处理动态输入尺寸的能力呢?其实不尽然,在官方提供的分类 Swin-Transformer 实现中,我们依然需要指定输入图像的尺寸。这涉及到 Swin-Transformer 中的 shfit-window 注意力计算机制。
如上图所示,每个灰格代表一个图像 patch 对应的特征向量,而蓝色的格子则代表一个分窗,整张图就是图像的特征图。因为窗口偏移(shift)的原因,原本 4x4 的窗口大小,在边缘区域变成了一些更小的窗口。在 Swin-Transformer 中,为了高效计算这种情况下的窗口注意力,首先使用 torch.roll 函数,将原本的图像特征图循环偏移成右图所示的排布。之后,我们将这些原本小于 4x4 的边缘窗口组合,如 H 和 B 组合, I、G、C、A 组合,将所有窗口都拼凑成立了 4x4 的窗口。
但是如图的 H 和 B 虽然为了高效计算而临时组队成了一个窗口,但 H 窗口的特征向量不应该能注意到 B 窗口的特征向量。因此需要一个 mask,在计算属于 H 窗口的特征向量的注意力时,这个 mask 能够屏蔽属于 B 窗口的特征向量,使得 H 窗口只注意 H 窗口, B 窗口只注意 B 窗口。
为了便于理解 mask 的生成方式,我们以一个更小的特征图(4x4)及更小的窗口大小(2x2)为例,如下图所示,对特征图进行分窗,生成了 9 个窗口,对特征图进行偏移,并组合部分分窗后,生成了 4 个用于计算的分窗。这里每个窗口都对应了一种窗口组合情况,因此需要使用不同的 mask 来计算注意力。这里,我们以 attention_masks[1] 为例,其为一个 4 * 4 的矩阵,其中第 1 行只有第 1 列和第 3 列为白色,表示计算特征 ① 的注意力时,只考虑 ① 和 ③ 特征。
显而易见的是,需要生成多少 mask,取决于分窗后有多少个窗口;每个 mask 的内容,取决于对应窗口内的边缘窗口组合形式。而如果输入图片的尺寸发生变化,那么整体的特征图尺寸、分出的窗口数量也会发生变化,进而影响 mask 的计算。因此,如果要支持动态的输入尺寸,必须同样动态地生成这些 mask。
幸运的是,这种动态生成 mask 的计算量不高,也不会涉及到插值等操作。通过在前向推理时根据输入图像尺寸动态生成这些 mask,MMClassification 同样支持了 Swin-Transformer 的动态输入尺寸。
解决了以上两个问题,就可以使绝大部分 Transformer 结构的视觉主干网络支持动态的输入图像尺寸。
文章来源:【OpenMMLab】
2022-03-18 18:01