Vision Transformer 必读系列之图像分类综述(三): MLP、ConvMixer 和架构分析(上)

本文涉及的产品
简介: 在 Vision Transformer 大行其道碾压万物的同时,也有人在尝试非注意力的 Transformer 架构(如果没有注意力模块,那还能称为 Transformer 吗)。这是一个好的现象,总有人要去开拓新方向。相比 Attention-based 结构,MLP-based 顾名思义就是不需要注意力了,将 Transformer 内部的注意力计算模块简单替换为 MLP 全连接结构,也可以达到同样性能。典型代表是 MLP-Mixer 和后续的 ResMLP。


1. MLP-based



640.png

在 Vision Transformer 大行其道碾压万物的同时,也有人在尝试非注意力的 Transformer 架构(如果没有注意力模块,那还能称为 Transformer 吗)。这是一个好的现象,总有人要去开拓新方向。相比 Attention-based 结构,MLP-based 顾名思义就是不需要注意力了,将 Transformer 内部的注意力计算模块简单替换为 MLP 全连接结构,也可以达到同样性能。典型代表是 MLP-Mixer 和后续的 ResMLP


1.1  MLP-Mixer


虽然 CNN 的卷积操作和 Vision Transformer 注意力在各个架构中都足以获得良好的性能,但它们都不是必需的,如果替换为本文设计的 MLP 结构依然可以取得一致性性能。

640.png



  • 将图片切分成不重叠的 patch 块,
    将patch 输入到 Pre-patch FC 层中,对每个 patch 进行线性映射,这两个步骤实际上就是 patch embeding 过程,假设输出是 (Patch, C),不同的颜色块代表不同的 patch。
  • 将上述 (Patch, C) 输入到 N 个 Mixer Layer 中进行特征提取。
  • 最后输出序列经过 global average pooling 聚合特征,然后接上 FC 层进行分类即可。


Mixer Layer 中整体结构和 Transformer 编码器类似,只不过内部不存在自注意力模块,而是使用两个不同类型的 MLP 代替,其分别是 channel-mixing MLPs 和 token-mixing MLPs。channel-mixing MLPs 用于在通道 C 方向特征混合,从上图中的 Channels (每个通道颜色一样) 可以明显看出其做法,而 token-mixing MLPs 用于在不同 patch 块间进行特征混合,其作用于 patch 方向。
在极端情况下,上述两个 Mixer Layer 可以看出使用 1×1 卷积进行通道混合,并使用全感受野的和参数共享的单通道深度卷积进行 patch 混合。反之则不然,因为典型的 CNN 不是 Mixer 的特例。此外卷积比 MLP 中的普通矩阵乘法更复杂,因为它需要对矩阵乘法和/或专门的实现进行额外的成本降低。代码如下所示,非常简洁。

# 代码是先进行 token mixing 再进行 channel mixing
class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int
  @nn.compact
  def __call__(self, x):
    # (b, patch, c)
    y = nn.LayerNorm()(x)
    # 交互为 (b, c, patch)
    y = jnp.swapaxes(y, 1, 2)
    # MlpBlock 作用于 patch 维度,实现 token mixing
    y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y)
    # 交换回来
    y = jnp.swapaxes(y, 1, 2)
    x = x + y
    y = nn.LayerNorm()(x)
    # MlpBlock 作用于 C 维度,实现 channel mixing
    return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y)
class MlpBlock(nn.Module):
  mlp_dim: int
  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)  

结果如下所示,可以看出性能和 ViT 非常接近。

640.png


1.2 ResMLP


几乎在同时,ResMLP 也沿着这条思路也进行了一些尝试。示意图如下所示:

640.png

从上图来看,几乎和 MLP-Mixer 一样,最核心的两个 MLP 层也是分成跨 patch 交互 MLP 层和跨通道 MLP 层。最后输出也是采用 avg pool 进行聚合后分类。
同时作者观察到如下现象:

  • 当使用与 DeiT 和 CaiT 相同的训练方案时,ResMLP 的训练比 ViTs 更稳定,不再需要 BatchNorm、GroupNorm 或者 Layer Norm 等归一化层。作者推测这种稳定性来自于用线性层代替自注意力。
  • 使用线性层的另一个优点是仍然可以可视化 patch embeding 之间的相互作用,揭示了类似于和 CNN 一样的学习特性即前面层抽取底层特征,后续层抽取高维语义特征。


MLP-Based 类算法相比 ViT 算法,有如下好处:

  • 不再需要自注意力模块。
  • 不再需要位置编码。
  • 不再需要额外的 class token。
  • 不再需要 Batch 等 Norm 统计算子,只需要引入可学习的 affine 层即可。


通过 MLP-Mixer 和 ResMLP 大家逐渐意识到 ViT 成功的关键可能并不是注意力机制,这也间接说明了目前大家对视觉 Transformer 架构理解度还是不够,还有很多研究空间。


1.3 CycleMLP


众所周知,MLP 一个非常大的弊端是无法自适应图片尺寸,这对下游密集预测任务不友好,MLP-Mixer 和 ResMLP 都存在无法方便用于下游任务的问题,基于这个缺点,CycleMLP 对 MLP 引入周期采样功能,使其具备了自适应图片尺寸的功能,大大提升了 MLP-based 类算法的实用性。其核心做法如下所示:

640.png

将 FC 作用于 Channel 通道即 Channel FC 层可以实现自适应图片尺寸功能,因为其信息聚合维度是 C,而这个维度本身是不会随着图片尺寸而改变,将 FC 作用于 Spatial 维度即 Spatial FC 层无法实现自适应图片尺寸功能,HW 维度会随着图片大小而改变。MLP-Mixer 和 ResMLP 为了聚合信息在自注意力层都会包含  Spatial FC 层和 Channel FC 层。
Spatial FC 层的主要作用是进行 patch 或者序列之间的信息交互,比较关键,无法简单的移除。可以从下面对比实验看下:

640.png

为了能够在移除 Spatial FC 层但依然保持 patch 或者序列之间的信息交互能力,论文提出一种循环采样 FC 层 Cycle FC,其属于局部窗口计算机制,周期性地在空间维度进行有序采样,和可变形卷积做法非常类似,实际上代码确实是直接采用可变形卷积实现的,由于比较清晰就不再分析计算过程了。作者绘制了非常详细的可视化图:

640.png

不再将自注意力层输入看出序列 (N,C) 格式,而是视为图片特征 (H, W, C) 格式,此时 Cycle FC 采样的空间维度可以选择 H 或者 W,如果设置 S_H=3 S_W=1,则如图 c 所示,采样方向是 H 方向,W 方向不进行聚合。而  Cycle MLP 模块(下图的 Spatial Proj 模块)实际上由 1x7 的 Cycle FC  层、 7x1 的 Cycle FC  层和 Channel FC 并联,然后相加构成。

640.png

代码如下所示:

class CycleMLP(nn.Module):
    def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)
        # W 方向采样
        self.sfc_h = CycleFC(dim, dim, (1, 3), 1, 0)
        # H 方向采样
        self.sfc_w = CycleFC(dim, dim, (3, 1), 1, 0)
        # 通道方向
        self.reweight = Mlp(dim, dim // 4, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, H, W, C = x.shape
        h = self.sfc_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        w = self.sfc_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        c = self.mlp_c(x)
        a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)
        x = h * a[0] + w * a[1] + c * a[2]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

作者也对采样过程进行深入分析,对比了另外 2 种采样策略。

  • 随机采样,并设置了 3 次不同的种子分别进行实现
  • 空洞 stepsize 采样


空洞 stepsize 采样即每隔 stepsize 个位置再进行采样。

640.png

实验表明,空洞采样效果比随机采样好,但是不如循环采样。
因为 FC 也可以用 Conv 层代替,作者又和 Conv 1x3 和 3x1 核进行比较,Conv 模式相比 CycleMLP 属于密集采样模式,效果居然比稀疏的 CycleMLP 差,作者分析原因是:密集采样模式会引入额外的参数量,为了快速比较将 epoch 相应的缩短到 100 epoch,在这种设置下,可能密集采样会引入额外的冗余参数,不利于优化。640.jpg

在构建好 CycleMLP 后,将带 CycleMLP FC 层自注意力模块应用于 PVT 和 Swin 中即可无缝的应用于各种密集下游任务,其详细的结构图如下所示,所构建的网络性能优于目前主流的 Swin Transformer。640.png



文章来源:【OpenMMLab

2022-01-27 18:24



相关实践学习
基于函数计算一键部署掌上游戏机
本场景介绍如何使用阿里云计算服务命令快速搭建一个掌上游戏机。
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
目录
相关文章
|
2天前
|
Dubbo Cloud Native 网络协议
【Dubbo3技术专题】「服务架构体系」第一章之Dubbo3新特性要点之RPC协议分析介绍
【Dubbo3技术专题】「服务架构体系」第一章之Dubbo3新特性要点之RPC协议分析介绍
42 1
|
2天前
|
设计模式 安全 Java
【分布式技术专题】「Tomcat技术专题」 探索Tomcat技术架构设计模式的奥秘(Server和Service组件原理分析)
【分布式技术专题】「Tomcat技术专题」 探索Tomcat技术架构设计模式的奥秘(Server和Service组件原理分析)
39 0
|
2天前
|
机器学习/深度学习 人工智能 自然语言处理
一文介绍CNN/RNN/GAN/Transformer等架构 !!
一文介绍CNN/RNN/GAN/Transformer等架构 !!
22 4
|
2天前
|
机器学习/深度学习 自然语言处理 并行计算
一文搞懂Transformer架构的三种注意力机制
一文搞懂Transformer架构的三种注意力机制
32 1
|
2天前
|
机器学习/深度学习 自然语言处理
【大模型】在大语言模型的架构中,Transformer有何作用?
【5月更文挑战第5天】【大模型】在大语言模型的架构中,Transformer有何作用?
|
2天前
|
存储 运维 监控
|
2天前
|
机器学习/深度学习 人工智能 自然语言处理
革命新架构掀翻Transformer!无限上下文处理,2万亿token碾压Llama 2
【4月更文挑战第28天】清华大学研究团队提出了Megalodon,一种针对长序列数据优化的Transformer模型。为解决Transformer的计算复杂度和上下文限制,Megalodon采用了CEMA改进注意力机制,降低计算量和内存需求;引入时间步长归一化层增强稳定性;使用归一化注意力机制提升注意力分配;并借助预归一化与双跳残差配置加速模型收敛。在与Llama 2的对比实验中,Megalodon在70亿参数和2万亿训练token规模下展现出更优性能。论文链接:https://arxiv.org/abs/2404.08801
28 2
|
2天前
|
资源调度 分布式计算 Hadoop
【Hadoop Yarn】YARN 基础架构分析
【4月更文挑战第7天】【Hadoop Yarn】YARN 基础架构分析
|
2天前
|
存储 负载均衡 NoSQL
【分布式技术架构】「Tomcat技术专题」 探索Tomcat集群架构原理和开发分析指南
【分布式技术架构】「Tomcat技术专题」 探索Tomcat集群架构原理和开发分析指南
54 1
|
2天前
|
存储 Java 应用服务中间件
【分布式技术专题】「架构实践于案例分析」盘点互联网应用服务中常用分布式事务(刚性事务和柔性事务)的原理和方案
【分布式技术专题】「架构实践于案例分析」盘点互联网应用服务中常用分布式事务(刚性事务和柔性事务)的原理和方案
66 0