以动制动 | Transformer 如何处理动态输入尺寸

简介: 为什么 Transformer 结构的网络中需要指定输入的图像尺寸呢?我们能否移除这个限制,让网络动态地支持各种尺寸的输入图像呢?这对于一些下游任务有重要的作用,也已经有了一些成熟的解决方案。在最新版的 MMClassification 中,我们将这一功能扩展到了各种基于 Transformer 结构的主干网络中,实现了分类任务与下游任务主干网络的统一。

从一个参数说起



在图像分类任务中,主干网络是视觉神经网络中进行图像特征提取的主体,常见的算法包括我们耳熟能详的 ResNet、Vision Transformer 等。


不知道大家是否注意到,用于图像分类的主干网络中,基于 CNN 结构的网络,通常不需要我们指定输入图像的尺寸,同时,同一个主干网络就能够处理各种尺寸的图像输入。但基于 Transformer 结构的主干网络,就往往需要我们在搭建网络时指定输入的图像尺寸参数 — img_size,而且网络的前向推理输入也必须是符合这一尺寸的图像。


那么,为什么 Transformer 结构的网络中需要指定输入的图像尺寸呢?我们能否移除这个限制,让网络动态地支持各种尺寸的输入图像呢?这对于一些下游任务有重要的作用,也已经有了一些成熟的解决方案。在最新版的 MMClassification 中,我们将这一功能扩展到了各种基于 Transformer 结构的主干网络中,实现了分类任务与下游任务主干网络的统一。


接下来,就让我们了解一下,Transformer 结构网络支持动态输入尺寸的阻碍与解决方法。


“罪魁祸首”——位置编码



说起 Transformer 结构,大家最先想到的关键结构大概率是注意力模块,但这里,问题并不出在注意力模块中,因为注意力模块天然地支持动态尺寸输入。让我们看下这张经典的 ViT 结构图:640.png

首先,我们会将输入图片按照一个固定的 patch size 切分成若干个 patch。之后每个图像 patch 经过一个线性映射得到对应的一个特征向量。这一个个特征向量如果按照其对应 patch 在图像上的位置排列,就是一张图像经过编码后的特征图,其长和宽分别等于原图在纵向和横向切分成了多少个 patch。之后,我们需要给这张特征图加上位置编码(position embedding),以体现每个 patch 在图像上的位置。


当输入图片尺寸发生变化时,由于每个 patch 的尺寸固定,图片切分出的 patch 数就会发生变化。表现在上述特征图中,就是特征图的尺寸发生了变化。这样一来,我们原本位置编码图的尺寸就和图像特征图的尺寸对不上了,无法进行后续的计算。


找到了问题所在,解决的方法也就顺理成章了。位置编码代表的是 patch 所在位置的附加信息,那么如果和图像特征图的尺寸不匹配,只需要使用双三次插值法(Bicubic)对位置编码图进行插值缩放,缩放到与图像特征图一致的尺寸,就同样可以表现每个 patch 在图片中的位置信息。

import torch
import torch.nn.functional as F
# 原始位置编码
pos_embed = torch.rand(1, 197, 64)
# 原始图像尺寸下,长和宽方向的 patch 数
src_shape = (14, 14)
# 输入图像尺寸下,长和宽方向的 patch 数
dst_shape = (16, 16)
# 额外编码数,在 ViT 中,为 1,指 class embedding;在 DeiT 中为 2
num_extra_tokens = 1
_, L, C = pos_embed.shape
src_h, src_w = src_shape
# 位置编码第二个维度大小应当等于 patch 数 + 额外编码数
assert L == src_h * src_w + num_extra_tokens
# 拆分额外编码和纯位置编码
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, num_extra_tokens:]
# 将位置编码组织成 (1, C, H, W) 形式,其中 C 为通道数
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
# 进行双三次插值
dst_weight = F.interpolate(src_weight, size=dst_shape, mode='bicubic')
# 重组位置编码为(1,H*W, C)形式,再拼接上额外编码,即获得新的位置编码
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
pos_embed = torch.cat((extra_tokens, dst_weight), dim=1)

在官方实现中,这一方法已经有了一些应用。不过源码只把这一方法用于微调模型时,加载和微调模型输入尺寸不同的预训练模型权重。而在 MMClassificaiton 中,我们将这一方法应用于每次模型的前向推理,使每次推理都可以应对不同尺寸的图像输入。


需要提醒的是,就像缩放照片会损失信息,这种对位置编码的插值也不是无损的,建议输入图像的尺度变化不要过大,同时需要在动态尺度输入下进行新的微调训练。


下面一个例子,展示了在 MMClassification 中使用 ViT 模型处理不同尺寸输入的流程:

import torch
from mmcls.models import build_backbone
cfg = dict(type='VisionTransformer', arch='base')
vit_model = build_backbone(cfg)
inputs = torch.rand(1, 3, 224, 224)
patch_embed, cls_token = vit_model(inputs)[-1]  # 获取模型最后一层输出
assert patch_embed.shape == (1, 768, 14, 14)
inputs = torch.rand(1, 3, 256, 384)
patch_embed, cls_token = vit_model(inputs)[-1]
assert patch_embed.shape == (1, 768, 16, 24)  # 输入尺寸不同,输出特征图的尺寸也不同


特殊的 Swin-Transformer



Position embedding 的问题遍布于经典 ViT 结构的主干网络中,但并不存在于 Swin-Transformer 中。对 Swin-Transformer 有了解的读者应该知道,在 Swin-Transformer 中,没有使用绝对位置编码,也即上文所说的那种与输入图像 patch 一一对应的位置编码;而是配合窗口注意力机制,使用了一种局限于窗口内部的相对位置编码机制。当我们改变输入图像的大小,可能会改变窗口的数量,但并不会影响窗口内部的相对位置编码。


那么 Swin-Transformer 是否天然地具备处理动态输入尺寸的能力呢?其实不尽然,在官方提供的分类 Swin-Transformer 实现中,我们依然需要指定输入图像的尺寸。这涉及到 Swin-Transformer 中的 shfit-window 注意力计算机制。

640.png

如上图所示,每个灰格代表一个图像 patch 对应的特征向量,而蓝色的格子则代表一个分窗,整张图就是图像的特征图。因为窗口偏移(shift)的原因,原本 4x4 的窗口大小,在边缘区域变成了一些更小的窗口。在 Swin-Transformer 中,为了高效计算这种情况下的窗口注意力,首先使用 torch.roll 函数,将原本的图像特征图循环偏移成右图所示的排布。之后,我们将这些原本小于 4x4 的边缘窗口组合,如 H 和 B 组合, I、G、C、A 组合,将所有窗口都拼凑成立了 4x4 的窗口。


但是如图的 H 和 B 虽然为了高效计算而临时组队成了一个窗口,但 H 窗口的特征向量不应该能注意到 B 窗口的特征向量。因此需要一个 mask,在计算属于 H 窗口的特征向量的注意力时,这个 mask 能够屏蔽属于 B 窗口的特征向量,使得 H 窗口只注意 H 窗口, B 窗口只注意 B 窗口。


为了便于理解 mask 的生成方式,我们以一个更小的特征图(4x4)及更小的窗口大小(2x2)为例,如下图所示,对特征图进行分窗,生成了 9 个窗口,对特征图进行偏移,并组合部分分窗后,生成了 4 个用于计算的分窗。这里每个窗口都对应了一种窗口组合情况,因此需要使用不同的 mask 来计算注意力。这里,我们以 attention_masks[1] 为例,其为一个 4 * 4 的矩阵,其中第 1 行只有第 1 列和第 3 列为白色,表示计算特征 ① 的注意力时,只考虑 ① 和 ③ 特征。

640.png


显而易见的是,需要生成多少 mask,取决于分窗后有多少个窗口;每个 mask 的内容,取决于对应窗口内的边缘窗口组合形式。而如果输入图片的尺寸发生变化,那么整体的特征图尺寸、分出的窗口数量也会发生变化,进而影响 mask 的计算。因此,如果要支持动态的输入尺寸,必须同样动态地生成这些 mask。


幸运的是,这种动态生成 mask 的计算量不高,也不会涉及到插值等操作。通过在前向推理时根据输入图像尺寸动态生成这些 mask,MMClassification 同样支持了 Swin-Transformer 的动态输入尺寸。


解决了以上两个问题,就可以使绝大部分 Transformer 结构的视觉主干网络支持动态的输入图像尺寸。


文章来源:【OpenMMLab

 2022-03-18 18:01


目录
相关文章
|
2天前
|
编解码 人工智能 测试技术
无需训练,这个新方法实现了生成图像尺寸、分辨率自由
【4月更文挑战第25天】研究人员提出FouriScale方法,解决了扩散模型在生成高分辨率图像时的结构失真问题。通过膨胀卷积和低通滤波,该方法实现不同分辨率下图像的结构和尺度一致性,无需重新训练模型。实验显示FouriScale在保持图像真实性和完整性的同时,能生成任意尺寸的高质量图像,尤其在处理高宽比图像时表现出色。尽管在极高分辨率生成上仍有局限,但为超高清图像合成技术提供了新思路。[链接: https://arxiv.org/abs/2403.12963]
33 5
|
2天前
|
算法
【MFAC】基于全格式动态线性化的无模型自适应控制
【MFAC】基于全格式动态线性化的无模型自适应控制
|
8月前
|
机器学习/深度学习 传感器 数据可视化
【免费】以 3D 形式显示热图、高程或天线响应模式表面数据附matlab代码
【免费】以 3D 形式显示热图、高程或天线响应模式表面数据附matlab代码
|
移动开发 文字识别 算法
论文推荐|[PR 2019]SegLink++:基于实例感知与组件组合的任意形状密集场景文本检测方法
本文简要介绍Pattern Recognition 2019论文“SegLink++: Detecting Dense and Arbitrary-shaped Scene Text by Instance-aware Component Grouping”的主要工作。该论文提出一种对文字实例敏感的自下而上的文字检测方法,解决了自然场景中密集文本和不规则文本的检测问题。
1883 0
论文推荐|[PR 2019]SegLink++:基于实例感知与组件组合的任意形状密集场景文本检测方法
|
2天前
|
算法
【MFAC】基于紧格式动态线性化的无模型自适应迭代学习控制
【MFAC】基于紧格式动态线性化的无模型自适应迭代学习控制
【MFAC】基于紧格式动态线性化的无模型自适应迭代学习控制
|
9月前
|
编解码 算法 数据可视化
【多重信号分类】超分辨率测向方法——依赖于将观测空间分解为噪声子空间和源/信号子空间的方法具有高分辨率(HR)并产生准确的估计(Matlab代码实现)
【多重信号分类】超分辨率测向方法——依赖于将观测空间分解为噪声子空间和源/信号子空间的方法具有高分辨率(HR)并产生准确的估计(Matlab代码实现)
|
7月前
|
前端开发 芯片
【芯片前端】保持代码手感——不重叠序列检测
【芯片前端】保持代码手感——不重叠序列检测
|
8月前
已知堆叠体的三视图,求堆叠体体积:俯视图标注法的使用
已知堆叠体的三视图,求堆叠体体积:俯视图标注法的使用
47 0
|
9月前
|
机器学习/深度学习 编解码 数据可视化
ConvNeXt V2:与屏蔽自动编码器共同设计和缩放ConvNets,论文+代码+实战
ConvNeXt V2:与屏蔽自动编码器共同设计和缩放ConvNets,论文+代码+实战
|
12月前
|
计算机视觉
当「分割一切」遇上图像修补:无需精细标记,单击物体实现物体移除、内容填补、场景替换(1)
当「分割一切」遇上图像修补:无需精细标记,单击物体实现物体移除、内容填补、场景替换