1. MLP-based
在 Vision Transformer 大行其道碾压万物的同时,也有人在尝试非注意力的 Transformer 架构(如果没有注意力模块,那还能称为 Transformer 吗)。这是一个好的现象,总有人要去开拓新方向。相比 Attention-based 结构,MLP-based 顾名思义就是不需要注意力了,将 Transformer 内部的注意力计算模块简单替换为 MLP 全连接结构,也可以达到同样性能。典型代表是 MLP-Mixer 和后续的 ResMLP。
1.1 MLP-Mixer
虽然 CNN 的卷积操作和 Vision Transformer 注意力在各个架构中都足以获得良好的性能,但它们都不是必需的,如果替换为本文设计的 MLP 结构依然可以取得一致性性能。
将图片切分成不重叠的 patch 块,将patch 输入到 Pre-patch FC 层中,对每个 patch 进行线性映射,这两个步骤实际上就是 patch embeding 过程,假设输出是 (Patch, C),不同的颜色块代表不同的 patch。- 将上述 (Patch, C) 输入到 N 个 Mixer Layer 中进行特征提取。
- 最后输出序列经过 global average pooling 聚合特征,然后接上 FC 层进行分类即可。
Mixer Layer 中整体结构和 Transformer 编码器类似,只不过内部不存在自注意力模块,而是使用两个不同类型的 MLP 代替,其分别是 channel-mixing MLPs 和 token-mixing MLPs。channel-mixing MLPs 用于在通道 C 方向特征混合,从上图中的 Channels (每个通道颜色一样) 可以明显看出其做法,而 token-mixing MLPs 用于在不同 patch 块间进行特征混合,其作用于 patch 方向。
在极端情况下,上述两个 Mixer Layer 可以看出使用 1×1 卷积进行通道混合,并使用全感受野的和参数共享的单通道深度卷积进行 patch 混合。反之则不然,因为典型的 CNN 不是 Mixer 的特例。此外卷积比 MLP 中的普通矩阵乘法更复杂,因为它需要对矩阵乘法和/或专门的实现进行额外的成本降低。代码如下所示,非常简洁。
# 代码是先进行 token mixing 再进行 channel mixing class MixerBlock(nn.Module): """Mixer block layer.""" tokens_mlp_dim: int channels_mlp_dim: int @nn.compact def __call__(self, x): # (b, patch, c) y = nn.LayerNorm()(x) # 交互为 (b, c, patch) y = jnp.swapaxes(y, 1, 2) # MlpBlock 作用于 patch 维度,实现 token mixing y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) # 交换回来 y = jnp.swapaxes(y, 1, 2) x = x + y y = nn.LayerNorm()(x) # MlpBlock 作用于 C 维度,实现 channel mixing return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y) class MlpBlock(nn.Module): mlp_dim: int @nn.compact def __call__(self, x): y = nn.Dense(self.mlp_dim)(x) y = nn.gelu(y) return nn.Dense(x.shape[-1])(y)
结果如下所示,可以看出性能和 ViT 非常接近。
1.2 ResMLP
几乎在同时,ResMLP 也沿着这条思路也进行了一些尝试。示意图如下所示:
从上图来看,几乎和 MLP-Mixer 一样,最核心的两个 MLP 层也是分成跨 patch 交互 MLP 层和跨通道 MLP 层。最后输出也是采用 avg pool 进行聚合后分类。
同时作者观察到如下现象:
- 当使用与 DeiT 和 CaiT 相同的训练方案时,ResMLP 的训练比 ViTs 更稳定,不再需要 BatchNorm、GroupNorm 或者 Layer Norm 等归一化层。作者推测这种稳定性来自于用线性层代替自注意力。
- 使用线性层的另一个优点是仍然可以可视化 patch embeding 之间的相互作用,揭示了类似于和 CNN 一样的学习特性即前面层抽取底层特征,后续层抽取高维语义特征。
MLP-Based 类算法相比 ViT 算法,有如下好处:
- 不再需要自注意力模块。
- 不再需要位置编码。
- 不再需要额外的 class token。
- 不再需要 Batch 等 Norm 统计算子,只需要引入可学习的 affine 层即可。
通过 MLP-Mixer 和 ResMLP 大家逐渐意识到 ViT 成功的关键可能并不是注意力机制,这也间接说明了目前大家对视觉 Transformer 架构理解度还是不够,还有很多研究空间。
1.3 CycleMLP
众所周知,MLP 一个非常大的弊端是无法自适应图片尺寸,这对下游密集预测任务不友好,MLP-Mixer 和 ResMLP 都存在无法方便用于下游任务的问题,基于这个缺点,CycleMLP 对 MLP 引入周期采样功能,使其具备了自适应图片尺寸的功能,大大提升了 MLP-based 类算法的实用性。其核心做法如下所示:
将 FC 作用于 Channel 通道即 Channel FC 层可以实现自适应图片尺寸功能,因为其信息聚合维度是 C,而这个维度本身是不会随着图片尺寸而改变,将 FC 作用于 Spatial 维度即 Spatial FC 层无法实现自适应图片尺寸功能,HW 维度会随着图片大小而改变。MLP-Mixer 和 ResMLP 为了聚合信息在自注意力层都会包含 Spatial FC 层和 Channel FC 层。
Spatial FC 层的主要作用是进行 patch 或者序列之间的信息交互,比较关键,无法简单的移除。可以从下面对比实验看下:
为了能够在移除 Spatial FC 层但依然保持 patch 或者序列之间的信息交互能力,论文提出一种循环采样 FC 层 Cycle FC,其属于局部窗口计算机制,周期性地在空间维度进行有序采样,和可变形卷积做法非常类似,实际上代码确实是直接采用可变形卷积实现的,由于比较清晰就不再分析计算过程了。作者绘制了非常详细的可视化图:
不再将自注意力层输入看出序列 (N,C) 格式,而是视为图片特征 (H, W, C) 格式,此时 Cycle FC 采样的空间维度可以选择 H 或者 W,如果设置 S_H=3 S_W=1,则如图 c 所示,采样方向是 H 方向,W 方向不进行聚合。而 Cycle MLP 模块(下图的 Spatial Proj 模块)实际上由 1x7 的 Cycle FC 层、 7x1 的 Cycle FC 层和 Channel FC 并联,然后相加构成。
代码如下所示:
class CycleMLP(nn.Module): def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias) # W 方向采样 self.sfc_h = CycleFC(dim, dim, (1, 3), 1, 0) # H 方向采样 self.sfc_w = CycleFC(dim, dim, (3, 1), 1, 0) # 通道方向 self.reweight = Mlp(dim, dim // 4, dim * 3) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, H, W, C = x.shape h = self.sfc_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) w = self.sfc_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) c = self.mlp_c(x) a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2) a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2) x = h * a[0] + w * a[1] + c * a[2] x = self.proj(x) x = self.proj_drop(x) return x
作者也对采样过程进行深入分析,对比了另外 2 种采样策略。
- 随机采样,并设置了 3 次不同的种子分别进行实现
- 空洞 stepsize 采样
空洞 stepsize 采样即每隔 stepsize 个位置再进行采样。
实验表明,空洞采样效果比随机采样好,但是不如循环采样。
因为 FC 也可以用 Conv 层代替,作者又和 Conv 1x3 和 3x1 核进行比较,Conv 模式相比 CycleMLP 属于密集采样模式,效果居然比稀疏的 CycleMLP 差,作者分析原因是:密集采样模式会引入额外的参数量,为了快速比较将 epoch 相应的缩短到 100 epoch,在这种设置下,可能密集采样会引入额外的冗余参数,不利于优化。
在构建好 CycleMLP 后,将带 CycleMLP FC 层自注意力模块应用于 PVT 和 Swin 中即可无缝的应用于各种密集下游任务,其详细的结构图如下所示,所构建的网络性能优于目前主流的 Swin Transformer。
文章来源:【OpenMMLab】
2022-01-27 18:24