Vision Transformer 必读系列之图像分类综述(二): Attention-based(上)

本文涉及的产品
图片翻译,图片翻译 100张
语种识别,语种识别 100万字符
云原生大数据计算服务MaxCompute,500CU*H 100GB 3个月
简介: Transformer 结构是 Google 在 2017 年为解决机器翻译任务(例如英文翻译为中文)而提出,从题目中可以看出主要是靠 Attention 注意力机制,其最大特点是抛弃了传统的 CNN 和 RNN,整个网络结构完全是由 Attention 机制组成。为此需要先解释何为注意力机制,然后再分析模型结构。


前言



Vision Transformer 必读系列之图像分类综述(一):概述 一文中,我们对 Vision Transformer 在图像分类中的发展进行了概述性总结。本文则对其中涉及的 Attention-based 部分进行详细说明。下一篇文章则会对概述中涉及的其他部分进行说明。
ViT 进展汇总思维导图如下图所示:

640.png


1. Transformer



论文题目:Attention is All You Need论文地址:https://arxiv.org/abs/1706.03762

Transformer 结构是 Google 在 2017 年为解决机器翻译任务(例如英文翻译为中文)而提出,从题目中可以看出主要是靠 Attention 注意力机制,其最大特点是抛弃了传统的 CNN 和 RNN,整个网络结构完全是由 Attention 机制组成。为此需要先解释何为注意力机制,然后再分析模型结构。


1.1 Attention 注意力机制


人生来就有注意力机制,看任何画面,我们会自动聚焦到特定位置特定物体上。此处的 Attention 机制也是同一个含义,对于需要的任何模态,不管是图像、文本、点云还是其他,我们都希望网络通过训练能够自动聚焦到有意义的位置,例如图像分类和检测任务,网络通过训练能够自动聚焦到待分类物体和待检测物体上。注意力机制不是啥新鲜概念,视觉算法中早已广泛应用,典型的如 SENet。

640.png

利用 Squeeze-and-Excitation 模块计算注意力权重概率分布,然后作用于特征图上实现对每个通道重加权功能。

可以举一个更简单的例子,假设有一个训练好的分类网络,输入一张图片,训练好的分类网络权重 W 和图片 X 进行注意力计算,从 X 中提取能够有助于分类的特征,该特征最终可以作为类别分类依据。W 和 X 都是矩阵,要想利用 W 矩阵来达到重加权 X 目的,等价于计算 W 和 X 的相似度(点乘),然后将该相似度变换为权重概率分布,再次作用于 X 上就可以以一个简单猫狗二分类例子说明。网络最终输出是 2x1 的向量,第一个数大则表示猫类别,否则为狗类别,假设网络已经训练好了,其 W 为 shape 为 2x1 的向量,值为 [[0.1, 0.5]],X 表示输入图片 shape 也是 2x1,其值为 [[0.1, 0.8]],可以看出其类别是狗,采用如下的计算步骤即可正确分类:


  • W 和 X 的转置相乘,即计算 W 中每个值和 X 中每个值的相似度,得到 2x2 矩阵,值为 [[0.01,0.08], [0.05,0.4]]。
  • 第二个维度进行 Softmax,将其转化为概率权重图即为 [[0.4825, 0.5175], [0.4134, 0.5866]]。
  • 将上述概率权重乘以 X,得到 shape 为 2x1 输出,值为 [[0.4622, 0.5106]]。
  • 此时由于第二个值大,所以正确分类为狗。


X 是含有狗的图片矩阵,能够正确分类的前提是训练好的 W 矩阵中第二个数大于第一个数。可以简单理解上述过程是计算 W 和 X 的相似度,如果两个向量相似(都是第二个比第一个数大),那么就分类为狗,否则就分类为猫


import torch
W = torch.tensor([[0.1, 0.5]]).view([1, -1, 1])
X = torch.tensor([[0.1, 0.8]]).view([1, -1, 1])
# 1 计算两个向量相似度
attn_output_weights = torch.bmm(W, X.transpose(1, 2))
# 2 转换为概率分布
attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
# 3 注意力加权
cls = torch.bmm(attn_output_weights, X)[0]
print(cls)

上述计算过程可以用如下公式表示:

640.png

对应到上面例子,Q 就是训练好的 W 矩阵,K 是图片输入,V 和 K 相等,其通用解释为利用 Q 查询矩阵和 K 矩阵进行相似度计算,然后转换为概率分布,此时概率值大的位置表示两者相似度大的部分,然后将概率分布乘上 V 值矩阵,从而用注意力权重分布加权了 V 矩阵,也就改变了 V 矩阵本身的分布。如果注意力机制训练的很好,那么提取的 V 应该就是我们想要的信息。分母 d_k 的平方根是为了避免梯度消失,当向量值非常大的时候,Softmax 函数会将几乎全部的概率分布都分配给了最大值对应的位置,也就是说所谓的锐化,通过除以分母可以有效避免梯度消失问题,稳定训练过程。

上述公式就是论文中提出的最重要的 Scaled Dot-Product Attention计算公式,先利用点乘计算 QK 矩阵的相似度,除以分母 d_k 平方根进行 Scaled 操作,然后 Softmax 操作将其转换为概率乘以 V 实现 Attention 功能。需要注意:为了让上面公式不会报错,其 Shape 关系必须为 Q - (N, M),K - (P, M),V - (P, G), 一般来说 K 和 V Shape 相同,但是 Q Shape 不一定和 K 相同。通过灵活地改变这些维度就可以控制注意力层的计算复杂度,后续大部分改进算法都有利用这一点


1.2 Transformer 结构分析


Transformer 是为了解决机器翻译任务而提出。机器翻译是一个历史悠久的问题,可以理解为序列转序列问题,也就是我们常说的 seq2seq 结构,解决这类问题一般是采用 encoder-decoder 结构,Transformer 也沿用了这种结构。翻译任务一个常规的解决方案如下所示:

640.png

对应到 Transformer 中的一个更具体的结构为:

640.png

主要包括编码器和解码器组件,编码器包括自注意力模块(QKV 来自同一个输入)和前向网络,解码器和编码器类似,只不过内部多了编码器和解码器交互的交叉注意力模块。
通常来说,标准的 Transformer 包括 6 个编码器和 6 个解码器串行。

  1. 编码器内部接收源翻译输入序列,通过自注意力模块提取必备特征,通过前向网络对特征进行进一步抽象。
  2. 解码器端输入包括两个部分,一个是目标翻译序列经过自注意力模块提取的特征,一个是编码器提取的全局特征,这两个输入特征向量会进行交叉注意力计算,抽取有利于目标序列分类的特征,然后通过前向网络对特征进行进一步抽象。
  3. 堆叠多个编码器和解码器,下一个编解码器接收来自上一个编解码的输出,构成串行结构不断抽取,最后利用解码器输出进行分类即可。


Transformer 完整结构如下所示:


640.png

编码器基本组件包括:源句子词嵌入模块 Input Embedding、位置编码模块 Positional Encoding、多头自注意力模块 Muti-Head Attention、前向网络模块 Feed Forward 以及必要的 Norm、Dropout 和残差模块。
解码器基本组件类似包括:目标句子词嵌入模块  Output Embedding、位置编码模块 Positional Encoding、带 mask 的自注意力模块 Masked Muti-Head Attention、交叉互注意力模块  Muti-Head Attention、前向网络模块  Feed Forward 、分类头模块 Linear+Softmax 以及必要的 Norm、Dropout 和残差模块。
由于本文重点是分析视觉方面的 Transformer,故没有必要对机器翻译过程进行深入解析,读者只需要理解每个模块的作用即可,而且视觉分类 Transformer 任务和 NLP 机器翻译任务不一样,实际上也不需要解码器模块,相比 NLP 任务会简单很多。


1.2.1 编码器基本组件


(1) 源句子词嵌入模块 Input Embedding
机器翻译是句子输入,句子输出,每个句子由单词构成,将句子编码成程序可以理解的向量过程就叫做词嵌入过程,也就是常说的 Word2Vec,对应到图像中称为 Token 化过程即如何将图像转换为更具语义的 Token,Token 概念会在 ViT 中详细描述。
(2) 多头自注意力模块 Muti-Head Attention

640.png

在 1.1 小节已经详细说明了注意力计算过程。左边是最简单的 Scaled Dot-Product Attention,单纯看上图你可以发现没有任何可学习参数,那么其存在的意义是啥?实际上可学习参数在 QKV 映射矩阵中,在自注意力模块中会对输入的向量分别乘上可学习映射矩阵 W_Q、W_K 和 W_V,得到真正的 Q、K 和 V 输入,然后再进行 Scaled Dot-Product Attention 计算。
为了增加注意力特征提取的丰富性,不会陷入某种局部特性中,一般会在注意力层基础上(单头注意力层)引入多个投影头,将 QKV 特征维度平均切分为多个部分(一般分成 8 部分),每个部分单独进行自注意力计算,计算结果进行拼接 。在特征维度平均切分,然后单独投影、计算,最后拼接可以迫使提取的注意力特征更加丰富。也就是上面的多头注意力模块 Multi-Head Attention。
(3) 前向网络模块 Feed Forward
前向网络模块主要是目的是对特征进行变换,其单独作用于每个序列(只对最后一个特征维度进行变换)。由于没有结构图,故直接贴相关代码,包括两个 Position-wise FC 层、激活层、Dropout层和 LayerNorm 层。

 
         



class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # 两个 fc 层,对最后的维度进行变换
        self.w_1 = nn.Linear(d_in, d_hid) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        residual = x
        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual
        x = self.layer_norm(x)
        return x

(4) Norm、Dropout 和残差模块
在每个注意力层后面和前向网络后都会接入 Dropout 、残差模块和  Layer Norm 模块。这些必要的措施对整个算法的性能提升非常关键。至于为啥用 Layer Norm 而不是 Batch Norm原因是机器翻译任务输入的句子不一定是同样长的,组成 Batch 训练时候会存在大量 Padding 操作,如果在 Batch 这个维度进行 Norm 会出现大量无效统计,导致 Norm 值不稳定,而  Layer Norm 是对每个序列单独计算,不考虑 Batch 影响,这样比较符合不定长序列任务学习任务。当然如果换成图像分类任务,则可以考虑使用 BN 层,后续有算法是直接采用 BN 的。
(5) 位置编码 Positional Encoding
考虑一个分类任务,输入一段句子判断是疑问句还是非疑问句?现在有两条语句分别是:

  • 不准在地铁上吃东西
  • 在地铁上吃东西准不


自注意力层的计算不会考虑字符间的顺序,因为每个字符都是单独和全局向量算相似度,也就是说上面两个句子输入进行注意力计算,输出的向量值是相同的,只不过相对位置有变化。如果我们对输出向量求和后值大于 0 还是小于 0 作为分类依据,那么上面两个句子输出相加值是完全相同的,那就始终无法区分到底是疑问句还是非疑问句,这就是我们常说的 Transformer 具有位置不变性。要解决这个问题,只需要让模型知道输入语句是有先后顺序的,位置编码可以解决这个问题。
加入位置信息的方式非常多,最简单的可以是直接将输入序列中的每个词按照绝对坐标 0,1,2 编码成相同长度的向量,然后和词向量相加即可。作者实际上提出了两种方式:

  • 网络自动学习,直接全 0 初始化向量,然后和词向量相加,通过网络学习来学习位置信息。
  • 自己定义规则,规则自己定,只要能够区分输入词顺序即可,常用的是 sincos 编码。


实际训练选择哪一种位置编码方式发现效果一致,但是不管哪一种位置编码方式都应该充分考虑在测试时候序列不定长问题,可能会出现测试时候非常长的训练没有见过的长度序列,后面会详细说明


1.2.2 解码器基本组件


其大部分组件都和编码器相同,唯一不同的是自注意力模块带有 mask,还额外引入了一个交叉注意力模块以及分类头模块。
(1) 带 mask 的自注意力模块
注意这个模块的输入是目标序列转化为词向量后进行自注意力计算。机器翻译是一个 seq2seq 任务,其真正预测是:最开始输入开始解码 token 代表解码开始,解码出第一个词后,将前面已经解码出的词再次输入到解码器中,按照顺序一个词一个词解码,最后输出解码结束 token,表示翻译结束。
也就是当解码时,在解码当前词的时候实际上不知道下一个词是啥,但在训练时,是将整个目标序列一起输入,然而注意力计算是全局的,每个目标单词都会和整个目标句子计算自注意力,这种训练和测试阶段的不一致性无法直接用于预测。为此我们需要在训练过程中计算当前词自注意力时候手动屏蔽掉后面的词,让模型不知道后面词。具体实现就是输入一个 mask 来覆盖掉后面的词。
由于这种特性是只存在于 NLP 领域,图片中不存在,故不再进行更深入分析。
(2) 交叉注意力
交叉注意力模块和自注意力模块相同,只不过其 QKV 来源不同,Q 来自解码器,KV 来自编码器,交叉注意力模块会利用 Q 来提取编码器提取的特征 KV,然后进行分类。
(3) 分类头
分类头就是普通的线性映射,转换输出维度为分类个数,然后采用 CE Loss 进行训练即可。


1.3 总结


Transformer 结构内部存在多个组件,但是最核心的还是注意力模块,在原始论文中作者也引入了大量的可视化分析来证明注意力模块的作用,有兴趣的建议阅读原文。可能作者自己也没有想到这篇论文会在视觉领域引起另一个全新的风尚,开辟出一条新的看起来前途一片光明的道路。

640.png

图片来自 A Survey of Visual Transformers

网址:https://arxiv.org/abs/2111.06091


2. Vision Transformer



论文题目:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale论文地址:https://arxiv.org/abs/2010.11929
ViT 是第一篇成功将 Transformer 引入到视觉领域且成功的尝试,开辟了视觉 Transformer 先河。其结构图如下所示:

640.png

其做法非常简单,简要概况为:

  • 将图片分成无重叠的固定大小 Patch (例如 16x16),然后将每个 Patch 拉成一维向量, n 个 Patch 相当于 NLP 中的输入序列长度(假设输入图片是 224x224,每个 patch 大小是 16x16,则 n 是 196),而一维向量长度等价于词向量编码长度(假设图片通道是 3, 则每个序列的向量长度是 768)。
  • 考虑到一维向量维度较大,需要将拉伸后的 Patch 序列经过线性投影 (nn.Linear) 压缩维度,同时也起到特征变换功能,这两个步骤可以称为图片 Token 化过程 (Patch Embedding)。
  • 为了方便后续分类,作者还额外引入一个可学习的 Class Token,该 Token 插入到图片 token 化后所得序列的开始位置。
  • 将上述序列加上可学习的位置编码输入到 N 个串行的 Transformer 编码器中进行全局注意力计算和特征提取,其中内部的多头自注意模块用于进行 Patch 间或者序列间特征提取,而后面的 Feed Forward (Linear+ GELU+Dropout+ Linear+ Dropout) 模块对每个 Patch 或者序列进行特征变换。
  • 将最后一个 Transformer 编码器输出序列的第 0 位置( Class Token 位置对应输出)提取出来,然后后面接 MLP 分类后,然后正常分类即可。


可以看出,图片分类无需 Transformer 解码器,且编码器几乎没有做任何改动,针对图像分类任务,只需单独引入一个 Image to Token 操作和 Class Token 的概念即可。
如何理解 Token?个人觉得任何包括图片更加高级的语义向量都可以叫做 Token,这个概念在 NLP 中应用非常广泛,表征离散化后的高级单词语义,在图像中则可以认为是将图像转化为离散的含更高级语义的向量。

ViT 证明纯 Transformer 也可以取得非常好的效果,相比 CNN 在数据量越大的情况下优势更加明显,但是 ViT 也存在如下问题:

  • 不采用超大的 JFT-300M 数据集进行预训练,则效果无法和 CNN 媲美,原因应该是 Transformer 天然的全局注意力计算,没有 CNN 这种 Inductive Bias 能力,需要大数据才能发挥其最大潜力。
  • ViT 无法直接适用于不同尺寸图片输入,因为 Patch 大小是固定的,当图片大小改变,此时序列长度就会改变,位置编码就无法直接适用了,ViT 解决办法是通过插值,这种做法一般会造成性能损失,需要通过 Finetune 模型来解决,有点麻烦。
  • 因为其直筒输出结构,无法直接应用于下游密集任务。


后面的文章对上述缺点采用了各种各样的改进,并提出了越来越先进的处理手段,推动了视觉 Transformer 的巨大进步。


3. 全局概述



由于内容非常多,为了更容易理解,我在拆解模块的基础上对每个模块进行分析,而不是对某篇文章进行概括。综述部分的分析流程按照结构图顺序描述,我将近期图像分类 Vision Transformer 发展按照 ViT 中是否包括自注意层模块来划分,包括:

  1. Attention-based, 这类算法是目前主流研究改进方向,包括了 Transformer 中最核心的自注意力模块。
  2. MLP-based, 这类算法不需要核心的自注意力模块,而是简单的通过 MLP 代替,也可以取得类似效果。
  3. ConvMixer-based,这类算既不需要自注意力模块,也不是单纯依靠 MLP,而是内部混合了部分 Conv 算子来实现类似功能。
  4. General architecture analysis,在这三类算法基础上也有很多学者在探讨整个 Transformer 架构,其站在一个更高的维度分析问题,不局限于是否包括自注意力模块,属于整体性分析。


在上述三个方向中,Attention-based 是目前改进最多,最热门的,也是本综述的核心。本文按照 3 个分类顺序依次分析,最后进行通用架构分析。通过 General architecture analysis 部分可以深化Attention-based、MLP-based 和 ConvMixer-based 三者的联系和区别。本文仅仅涉及 Attention-based 部分。

3.1 Attention-based


Attention-based 表示这类算法必然包括注意力模块,我们将按照广度优先顺序进行一次分析。

640.png


继 ViT 后,我们将其发展分成两条线路:训练策略和模型改进。

其中训练策略表示目前主流对 ViT 模型的训练改进方式,而模型改进则是对各个部件进行改进。

  • 训练策略包括两篇论文:DeiT 和 Token Labeling。两者提出的出发点一致,都是为了克服 ViT 需要 JFT-300M 大数据集进行预训练的缺点。DeiT 是通过引入蒸馏学习解决,而 Token Labeling 通过引入显著图然后施加密集监督解决。后续发展中大部分算法都是参考了 DeiT 的训练策略和超参设置,具有非常大的参考价值。


  • 模型改进方面,我将其分成了 6 个组件以及其他方面的改进,6 个组件包括:


  1. Token 模块,即如何将 image 转 token 以及 token 如何传递给下一个模块
  2. 位置编码模块
  3. 注意力模块,这里一般都是自注意力模块
  4. Fead Forward (FFN) 模块
  5. Norm 模块位置
  6. 分类预测头模块


下面按照训练策略和模型改进顺序分析。


3.1.1 训练策略


训练策略解决 ViT 需要大数据先预训练问题以及超参有待优化问题。


3.1.1.1 DeiT


如果说 ViT 开创了 Transformer 在视觉任务上面的先河,那么 DeiT 的出现则解决了 ViT 中最重要的问题:如果不采用超大的 JFT-300M 数据集进行预训练,则效果无法和 CNN 媲美。在单个节点 8 张 V100 且无需额外数据的情况下,用不到 3 天的时间训练所提的 ViT(86M 参数),在 ImageNet 上单尺度测试达到了 83.1% 的 top-1 准确率。
DeiT 核心是引入蒸馏手段加上更强的 Aug 和更优异的超参设置。其蒸馏的核心做法如下所示:

640.png

ViT 的 Class Token 是加到图片输入序列的前面,那么蒸馏 Token 可以插到输入序列的后面,当然插入到哪个位置其实无所谓,你也可以插入到 Class Token 后面,经过 Transformer 编码器输出的序列相比 ViT 也会多一个,然后额外的一个输出 Token 经过线性层输出相同类别通道,最后进行蒸馏学习。
对于蒸馏学习来说,做法通常有两个:

  • Soft 蒸馏,即学生模型和教师模型预测的 Softmax 概率分布值计算 KL Loss。
  • Hard 蒸馏,即教师模型预测的 Softmax 概率分布值中,值最大对应的类别作为标签,然后和学生模型预测的 Softmax 概率分布值计算 CE Loss。


蒸馏学习中,通常教师模型会选择一个比学生模型性能更强的且已经提前训练好的模型,教师模型不需要训练,通过蒸馏 loss 将教师模型知识以一种归纳偏置的方式转移给学生模型,从而达到提升学生模型性能的目的。因为引入了额外的蒸馏 Token,而且该 Token 训练任务也是分类,所以实际上 DeiT 在推理时,是将 Class Token 和 Distillation Token 的预测向量求平均,再转换为概率分布。
为了证明 Distillation Token 的有效性,而不是只由于多了一个 Token 或者说多了一个可学习参数导致的,作者还做了对比试验,不加 Distillation Token,而是再加一个 Class Token,相当于有两个分类头,两个 Token 独立且随机初始化,实验发现他们最终收敛后两个分类 Token 的相似度达到 0.999,并且性能更弱,这样证明了加入 Distillation Token 的意义。
通过大量实验,作者总结了如下结论:

  • 蒸馏做法确实有效,且 Hard 蒸馏方式效果会更好,泛化性能也不错
  • 使用 RegNet 作为教师网络可以取得更好的性能表现,也就是说相比 Transformer,采用卷积类型的教师网络效果会更好


除了上述蒸馏策略,还需要特别注意 DeiT 引入了非常多的 Aug 并且提供了一套更加优异的超参,这套参数也是后续大部分分类模型直接使用的训练参数,非常值得学习,如下所示:

640.png

总而言之, DeiT 算法非常优异,实验也非常多(建议去阅读下),最大贡献是通过蒸馏策略省掉了 ViT 需要 JFT-300M 数据集进行预训练这个步骤,并且提供了一套非常鲁棒且实用的超参配置,深深地影响了后续的大部分图像分类视觉  Transformer 模型。


3.1.1.2 Token Labeling


DeiT 不是唯一一个解决 ViT 需要大数据量问题的算法,典型的还有 Token Labeling,其在 ViT 的 Class Token 监督学习基础上,还对编码器输出的每个序列进行额外监督,相当于将图片分类任务转化为了多个输出 Token 识别问题,并为每个输入 Patch 的预测 Token 分配由算法自动生成的基于特定位置的监督信号,简要图如下所示:

640.png

从上图明显可以看出,相比 ViT 额外多了输出 Token 的监督过程,这些监督可以当做中间监督,监督信息是通过 EfficientNet 或者 NFNet ( F6 86.3% Top-1 accuracy) 这类高性能网络对训练图片提前生成的显著图,每个显著图维度是和类别一样长的 C 维,辅助 Loss 和分类一样也是 CE Loss。当然最终实验结果表明性能比 DeiT 更优异,而且由于这种密集监督任务,对于下游密集预测任务泛化性也更好。
在此基础上 DeiT 已经证明通过对 ViT 引入更多的强 Aug 可以提升性能,例如引入 CutMix,但是本文的做法无法直接简单增加 CutMix,为此作者还专门设计了一个 MixToken,大概做法是在 Pathc Embedding 后,对 Token 进行了相应的 CutMix 操作。性能表如下所示:

640.png

LV-ViT 即为本文所提模型。相比 DeiT,作者认为本文做法更加优异,体现在:

  • 不需要额外的教师模型,是一个更加廉价的做法。
  • 相比于单向量监督,以密集的形式监督可以帮助训练模型轻松发现目标物体,提高识别准确率,实验也证明了对下游密集预测任务(例如语义分割)更友好。


下表是对训练技术的详细分析:

640.png

简而言之,Token Labeling 的核心做法是通过引入额外的显著图来监督每个 patch 输出的预测 token,虽然不需要教师模型,但是依然需要利用更优异的模型对所有训练图片生成显著图。


3.1.2 模型改进


在 DeiT 提出后,后续基于它提出了大量的改进模型,涉及到 ViT 的方方面面。前面说过 ViT 模型主要涉及到的模块包括:

  1. Token 模块,即如何将 image 转 token 以及 token 如何传递给下一个模块
  2. 位置编码模块
  3. 注意力模块,这里一般都是自注意力模块
  4. Fead Forward (FFN) 模块
  5. Norm 模块位置
  6. 分类预测模块


3.1.2.1 Token 模块


Token 模块包括两个部分:

  1. Image to Token 模块即如何将图片转化为 Token,一般来说分成有重叠和无重叠的 Patch Embedding 模块
  2. Token to Token 模块即如何在多个 Transformer 编码器间传递 Token,通常也可以分成固定窗口 Token 化过程和动态窗口 Token 化过程两个


下面是完整结构图:

640.png


3.1.2.1.1 Image to Token 模块


首先需要明确:Patch Embedding 通常包括图片窗口切分和线性嵌入两个模块,本小结主要是说图片窗口切分方式,而具体实现不重要,常用的 2 种实现包括 nn.Conv 和 nn.Unfold,只要设置其 kernel 和 stride 值相同,则为非重叠 Patch Embedding,如果 stride 小于 kernel 则为重叠  Patch Embedding。


(1)  非重叠 Patch Embedding

ViT 和目前主流模型例如 PVT 和 Swin Transformer 等都是采用了非重叠 Patch Embedding,其简要代码为:

# 非重叠只需要设置Conv kernel_size 和 stride 相同即可
_conv_cfg = dict(
    type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
_conv_cfg.update(conv_cfg)
self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims)
x = self.projection(x).flatten(2).transpose(1, 2)

通过设置 16x16 的 kernel 和 stride 可以将图片在空间维度进行切割,每个 patch 长度为 16x16x3,然后将每个 Patch 重排拉伸为一维向量后,经过线性层维度变换,输出 shape 为 (B, Num_Seq, Dim)。
在 TNT 中作者提出了一种更精细的非重叠 Patch Embedding 模块,如下图所示:

640.png

他的基本观点是自然图像的复杂度相较于自然文本更高,细节和颜色信息更丰富,而 ViT 的非重叠 Patch Embedding 做法过于粗糙,因为后续自注意力计算都是在不同 Patch 间,这会导致 Patch 内部的局部自注意力信息没有充分提取,而这些信息在图像中也是包含了不同的尺度和位置的物体特征,是非常关键的。故我们不仅要考虑 Patch 间自注意力,还要考虑 Patch 内自注意力,为此作者在 外层 Transformer 中又嵌入了一个内层 Transformer,相应的非重叠 Patch Embedding 也分成两步:整图的非重叠 Patch Embedding 过程和 Patch 内部更细粒度的非重叠 Patch Embedding 过程。
通过上图大概可以看出其具体做法,内部相当于有两个 Transformer,第一个 Transformer (Outer Transformer )和 ViT 完全一样,处理句子 Sentences 信息即图片 Patch 级别信息,第二个 Transformer (Inner Transformer,也需要额外加上 Inner Transformer 所需要的位置编码) 处理更细粒度的 Words 信息即图片 Patch 内再切分为 Patch,为了能够将两个分支信息融合,内部会将 Inner Transformer 信息和 Outer Transformer  相加。将上述 Transformer block 嵌入到 PVT 模型中验证了其对下游任务的适用性,通过进一步的可视化分析侧面验证了分类任务上 TNT 相比 DeiT 的优异性。

640.png


(2)  重叠 Patch Embedding


在常规的 CNN 网络中一般都是采用重叠计算方式,为此是否采用重叠  Patch Embedding 会得到更好的性能?直接将非重叠  Patch Embedding 通过修改 Unfold 或者 Conv 参数来实现重叠  Patch Embedding 功能的典型算法包括 T2T-ViT  和 PVTv2,这两个算法的出发点都是非重叠 Patch Embedding 可以加强图片 Patch 之间的连续性,不至于出现信息断层,性能应该会比重叠 Patch Embedding 高。PVTv2 内部采用 Conv 实现,而 T2T ViT 是通过 Unfold 方式实现(论文中称为 soft split)。

640.png

前面说过 CNN 网络中一般都是采用重叠计算方式,那么是否可以用 ResNet Stem 替换非重叠  Patch Embedding过程,性能是否会更好?
在 Early Convolutions Help Transformers See Better 论文中,作者进行了深度分析,虽然作者只是简单的将图片 Token 化的 Patch Embedding 替换为 ResNet Conv Stem,但是作者是从优化稳定性角度入手,通过大量的实验验证上述做法的有效性。作者指出 Patch Embedding 之所以不稳定,是因为该模块是用一个大型卷积核以及步长等于卷积核的卷积层来实现的,往往这个卷积核大小为 16*16,这样的卷积核参数量很大,而且随机性很高,从某种程度上造成了 Transformer 的不稳定,如果用多个小的卷积来代替则可以有效缓解。结构如下所示:

640.png

考虑了和 ViT 公平对比,新引入的 Conv Stem 计算量约等于一个 transformer block,故后续仅仅需要 L-1 个 transformer block。作者通过大量分析实验得到一些经验性看法:
(a) ViT 这类算法对 lr 和 wd 超参的选择非常敏感,而替换 Stem 后会鲁棒很多
(b) ViT 这类算法收敛比较慢,而本算法会快很多,例如都在 100 epoch 处本文性能远优于 ViT

640.png

ViT_p 即为 Patch Embedding 模式,ViT_c 即为 Conv Stem 模式,可以看出在不同 flops 下模型收敛速度都是 ViT_c 快于 ViT_p,虽然到 400 epoch 时候性能都非常接近。
(c) ViT 这类算法只能采用 AdamW 训练,而本文更加通用,采用 SGD 后性能没有显著下降。
众所周知,ViT 类模型都只能用 AdamW 训练,其占据显存是 SGD 的 3 倍,所以一般在 CNN 网络中都是用过 SGD 模型,性能通常不错,而通过替换  Patch Embedding  后也可以用 SGD 训练了。

640.png

(d) 仅仅采用 ImageNet 训练 ViT 性能难以超越 CNN,而本文可以进一步提升 ViT 性能。
与上述论文持相同观点的也包括  ResT 、Token Learner、CSWin Transformer 等算法,他们都采用了完全相同的做法。更进一步在 PS-ViT 中为了能够方便后续的渐进采样模块稳定提取更好的特征点,作者在 Image to Token 模块中不仅仅引入了 ResNet 的 Conv Stem 模块,还在后面再使用了 ResNet 第一个 stage 的前两个残差 block,在 Token to Token 模块中会详细说明 PS-ViT。
在 CeiT 中作者出发点是 CNN 中的诸多特性已经被证明是很成功的,纯粹的 Transformer 需要大量的数据、额外的监督才能达到和 CNN 相同的精度,出现这种问题的原因可能是 NLP 中的 Transformer  直接搬到图像任务中可能不是最合适的,应该考虑部分引入 CNN 来增强 Transformer。具体来说,在图片转 Token 方案中提出 Image-to-Tokens (I2T) 模块,不再是从图片中直接进行 Patch Emeding ,而是对 CNN 和 Pool 层所提取的底层特征进行 Patch Embedding,借助图像特征会比直接使用图片像素更好。

640.png

上图的上半部分是 ViT 的 Patch Embedding 过程,下图是 CeiT 所提出的做法,核心就是引入卷积操作提取底层特征,然后在底层特征上进行 Patch Embedding 操作。
既然采用 Conv Stem 可以解决很多问题,那么理论上经过精心设计的 Conv 结构也必然是有效的,例如 ViTAE 中就采用了空洞卷积做法,本质上是希望能够利用卷积提供多尺度上下文信息,这有助于后续模块信息提取,如下图所示:

640.png

对图片或者特征图应用多个不同空洞率的卷积提取信息后,进行拼接和 GeLU 激活函数后,直接拉伸为一维向量,从而转换为序列,并且由于空洞卷积可以实现下采样功能,故也可以有效地减少后续注意力模块计算量。


3.1.2.1.2 Token to Token 模块


大部分模型的 Token to Token 方案和 Image to Token 做法相同,但是也有些算法进行了相应改造。经过整理,我们将其分成两种做法:

  1. 固定窗口 Token 化
  2. 动态窗口 Token 化


固定窗口是指 Token 化过程是固定或者预定义的规则,典型的重叠和非重叠  Patch Embedding  就是固定窗口,因为其窗口划分都是提前订好的规则,不会随着输入图片的不同而不同,而动态窗口是指窗口划分和输入图片语义相关,不同图片不一样,是一个动态过程。

(1) 固定窗口 Token 化
这个做法通常和 Image to Token 模块完全一样,也可以分成非重叠  Patch Embedding  和重叠  Patch Embedding,大部分算法都属于这一类,例如 PVT、Swin Transformer 等。
(2) 动态窗口 Token 化
动态窗口 Token 化过程典型代表是 PS-ViT 和 TokenLearner。
前面说过,Vision Transformer with Progressive Sampling (PS-ViT) 中为了方便后续的渐进采样模块能够稳定提取更好的特征点,在 Image to Token 模块中不仅仅引入了 ResNet 的 Conv Stem 模块,还在后面再使用了 ResNet 第一个 stage 的前两个残差 block。在特征图 F 后,作者在 Token to Token 环节引入了一个渐进式采样模块,其出发点是 ViT 采用固定窗口划分机制,然后对每个窗口进行 Token 化,这种做法首先不够灵活,而且因为图片本身就是密集像素,冗余度非常高,采用固定划分方法对于分类来说可能就某几个窗口内的 Token 实际上才是有意义的,假设物体居中那么物体四周的 Token 可能是没有作用的,只会增加无效计算而已。基于此作者设计一个自适应采样的 Token 机制,不再是固定的窗口采样,而是先初始化固定采样点,如下图红色点所示,然后通过 refine 机制不断调整这些采样点位置,最终得到的采样点所对应的 Token 就是最有代表力的。其完整分类网络结构图如下所示:

640.png

得到特征图 F 后,经过渐进采样模块,不断 refine 采样点,最终输出和采样点个数个序列,将该序列作为 ViT 模型的输入即可。简单来看渐进采样模块起到了 Token to Token 作用。其中的渐进采样模块结构图如下所示:

640.png

详细计算过程如下:

  1. 首先图片经过 ResNet Conv Stem + ResNet 第一个 stage 的前两个残差块进行特征提取,得到 F。
  2. 在特征图或者原图上先设置初始均匀固定间隔采样点 pt,上图是 9 个采样点,表示最终序列长度是 9。
  3. 利用 pt 值对 F 进行采样,提取对应位置的特征向量,加上位置编码输入到编码器中,输出 T_t。
  4. 将 T_t 经过一个 FC 层生成 offset,将该 offset 和初始位置 pt 相加就可以得到 refine 后的 p_t+1。
  5. 将 3-4 步骤重复 N 次,下一个采样模块的输入包括 refine 后的 pt、特征图 F 和上一个采样模块的输出 T,三者相加。
  6. 经过 N 次 refine 后,将该 token 序列拼接上 class token,然后再经过 M 个编码器模块。
  7. 最后对 class token 对应位置输出 token 进行分类训练即可。


可以发现,和 ViT 的主要差异就在于其采样点不是固定的均匀间隔,而是基于语义图自适应,从而能够在减少计算量的前提下进一步提升性能。PS-ViT 在 top-1 精度方面比普通 ViT 高 3.8%,参数减少约 4 倍,FLOP 减少约 10 倍,性能比较优异。
基于类似出发点,TokenLearner 提出可以基于空间注意力自适应地学习出更具有代表性的 token,从而可以将 ViT 的 1024 个 token 缩减到 8-16 个 token,计算量减少了一倍,性能依然可以保持一致,这也侧面说明了 ViT 所采样的固定窗口 token 有大量冗余,徒增计算量而已。其核心示意图如下所示:

640.png

假设想仅仅采样出 8 个 token,首先采用 Conv Stem 提取图片特征,然后分别输入到 8 个空间注意力分支中,空间注意力分支首先会应用一系列卷积生成空间 attention 图,然后逐点和输入特征相乘进行特征加权,最后通过空间全局 pool 层生成 1x1xC 的 Token,这样就将 HXWXC 的特征图转换为了 8 个通道为 C 的 Token。
为了进一步提高信息,作者还额外提出一个 TokenFuser 模块,加强 Token 和 Token 之间的联系以及恢复空间结构,整个分类网络的结构如下所示:(a) 为不包括 TokenFuser 的改进 ViT 结构,(b) 为包括 TokenFuser 的改进 ViT 结构。

640.png

从上述结构可以发现, TokenLearner 模块起到了自适应提取更具语义信息的 Token,并且还能够极大地减少计算量,而 TokenFuser 可以起到加强 Token 和 Token 之间的联系以及恢复空间结构的功能,TokenLearner+Transformer+ TokenFuser 构成 Bottleneck 结构。其中 TokenFuser 示意图如下所示:

640.png

其接收两个输入,一个是 TokenLearner 前的保持了空间信息的 1024 个 token 特征,一个是 TokenLearner 后经过自适应采样的 8 个 token 特征,然后以注意力模式两者进行乘加操作,融合特征以及恢复空间结构。
作者的分类实验依然采用了 JFT-300M 数据集进行预训练,然后在 ImageNet 1k上面微调,也就是说和最原始的 ViT 进行比较。

640.png

TokenFuser 也进行了相应的对比实验。

640.png

at 6 表示 TokenLearner 插入到第 6 个 Transformer Encoder 后。


文章来源:【OpenMMLab

2022-01-26 18:00


目录
相关文章
|
机器学习/深度学习 编解码 并行计算
论文阅读笔记 | Transformer系列——CSWin Transformer
论文阅读笔记 | Transformer系列——CSWin Transformer
710 0
论文阅读笔记 | Transformer系列——CSWin Transformer
|
3月前
|
机器学习/深度学习 PyTorch 语音技术
【文献学习】Conformer: Convolution-augmented Transformer for Speech Recognition
文章介绍了Conformer模型,这是一种结合了Transformer的自注意力机制和CNN卷积模块的混合模型,旨在提高语音识别任务的性能,通过自注意力捕捉全局上下文信息,同时利用卷积模块有效捕获局部特征。
74 0
|
机器学习/深度学习 自然语言处理 并行计算
【Transformer系列(3)】 《Attention Is All You Need》论文超详细解读(翻译+精读)
【Transformer系列(3)】 《Attention Is All You Need》论文超详细解读(翻译+精读)
1404 0
【Transformer系列(3)】 《Attention Is All You Need》论文超详细解读(翻译+精读)
|
机器学习/深度学习 算法 数据可视化
深度学习论文阅读目标检测篇(一):R-CNN《Rich feature hierarchies for accurate object detection and semantic...》
 过去几年,在经典数据集PASCAL上,物体检测的效果已经达到 一个稳定水平。效果最好的方法是融合了多种低维图像特征和高维上 下文环境的复杂集成系统。在这篇论文里,我们提出了一种简单并且 可扩展的检测算法,可以在VOC2012最好结果的基础上将mAP值提 高30%以上——达到了53.3%。
161 0
深度学习论文阅读目标检测篇(一):R-CNN《Rich feature hierarchies for accurate object detection and semantic...》
|
计算机视觉
论文阅读笔记 | Transformer系列——Transformer in Transformer
论文阅读笔记 | Transformer系列——Transformer in Transformer
295 0
论文阅读笔记 | Transformer系列——Transformer in Transformer
|
机器学习/深度学习 编解码 自然语言处理
论文阅读笔记 | Transformer系列——Swin Transformer
论文阅读笔记 | Transformer系列——Swin Transformer
1198 0
论文阅读笔记 | Transformer系列——Swin Transformer
|
机器学习/深度学习 算法 大数据
Vision Transformer 必读系列之图像分类综述(三): MLP、ConvMixer 和架构分析(下)
在 Vision Transformer 大行其道碾压万物的同时,也有人在尝试非注意力的 Transformer 架构(如果没有注意力模块,那还能称为 Transformer 吗)。这是一个好的现象,总有人要去开拓新方向。相比 Attention-based 结构,MLP-based 顾名思义就是不需要注意力了,将 Transformer 内部的注意力计算模块简单替换为 MLP 全连接结构,也可以达到同样性能。典型代表是 MLP-Mixer 和后续的 ResMLP。
1156 0
Vision Transformer 必读系列之图像分类综述(三): MLP、ConvMixer 和架构分析(下)
|
机器学习/深度学习 算法 数据挖掘
【vision transformer】LETR论文解读及代码实战(一)
【vision transformer】LETR论文解读及代码实战
184 0
|
机器学习/深度学习 自然语言处理 数据可视化
阿里提出QuadTree Transformer | 最轻、最强的Vision Transformer Backbone(一)
阿里提出QuadTree Transformer | 最轻、最强的Vision Transformer Backbone(一)
244 0
|
计算机视觉
阿里提出QuadTree Transformer | 最轻、最强的Vision Transformer Backbone(二)
阿里提出QuadTree Transformer | 最轻、最强的Vision Transformer Backbone(二)
110 0