论文阅读笔记 | MLP系列——MLP部分汇总(MLP-Mixer、S2-MLP、AS-MLP、ViP、S2-MLPv2)

简介: 论文阅读笔记 | MLP系列——MLP部分汇总(MLP-Mixer、S2-MLP、AS-MLP、ViP、S2-MLPv2)

1. MLP-Mixer


image.png

详细笔记见:论文阅读笔记 | MLP系列——MLP-Mixer


2. S2-MLP


出发点:过拟合的角度

MLP-Mixer只在比较大的数据集上可以取得和 CNN 以及 Transformer 结构相当甚至更好的性能。然而,单单在 ImageNet 1k 或者 ImageNet 21K 上训练测试,其性能其实并不算太好。因为虽然 MLP-Mixer 增加了学习的自由性,没有给予局部性啊这些的约束,但是正因如此才更容易过拟合。所以只有当它在超大规模数据量的训练下才可能变得普适。为此,我们实际上还是得给一些约束或者指导,以帮助模型在中小规模数据上训练得更好。


  • S2-MLP Block结构:

image.png

空间位移操作其实就是把特征图按照不同方向进行了平移,这等价于一个卷积操作

image.png

  • spatial-shift 操作是固定权重的,其余二者是可学习的
  • spatial-shift 操作以及 DW 都是局部感受野,Token-mixing MLP 是全局感受野
  • spatial-shift 操作以及 DW 对空间不敏感,而 Token-mixing MLP 对空间是敏感的
  • Token-mixing MLP 在通道上是一致的,DC 的卷积核每个通道是互异的,spatial-shift 操作可被视为一种特殊的组卷积,即一组内的卷积核是一样的


伪代码表示:

def spatial_shift(x):
  w,h,c = x.size()
  x[1:,:,:c/4] = x[:w-1,:,:c/4]
  x[:w-1,:,c/4:c/2] = x[1:,:,c/4:c/2]
  x[:,1:,c/2:c*3/4] = x[:,:h-1,c/2:c*3/4]
  x[:,:h-1,3*c/4:] = x[:,1:,3*c/4:]
  return x


详细内容:S2MLP网络详解


3. AS-MLP


出发点:没有局部信息交流的角度

对于MLP结构来说,模型通过矩阵转置与token-mixing投影操作获取全局的感受野,从而覆盖了长距离依赖。但是,从操作上可以看出来,MLP-Mixer比较少地利用了局部信息(局部信息就是cnn的归纳偏置,在构建cnn模型时比较重要),而且也不是所有的像素点都需要长距离依赖(这也vit模型目前被改进的一个方向,长距离依赖就是vit的归纳偏差,现在希望增加局部信息操作来减少参数量,部分论文已经证实了局部信息的重要性及优势)。


  • AS-MLP Block结构:

image.png

AS-MLP结构框图:

image.png

感受野概念:

image.png

伪代码表示:

# norm:normalizationlayer
# proj:channelprojection
# actn:activationlayer
import torch
import torch.nn.functional as F
def shift(x, dim):
  x = F.pad(x, "constant", 0)
  x = torch.chunk(x, shift_size, 1)
  x = [ torch.roll(x_c, shift, dim) for x_s, shift in zip(x, range(-pad, pad+1))]
  x = torch.cat(x, 1)
  return x[:, :, pad:-pad, pad:-pad]
def as_mlp_block(x):
  shortcut = x
  x = norm(x)
  x = actn(norm(proj(x)))
  x_lr = actn(proj(shift(x, 3)))
  x_td = actn(proj(shift(x, 2)))
  x = x_lr + x_td
  x = proj(norm(x))
  return x + shortcut


详细内容:论文阅读笔记 | MLP系列——AS-MLP


4. ViP


出发点:2维特征表示会丢失空域信息

Mixer对空间信息进行编码,首先将空间维度扁平化,然后沿空间维度进行线性投影(既对tokens×channels的形式进行运算)。而2D特征表示会导致所携带的空域信息丢失。Vision Permutator保持输入标记的原始空间维数,并沿着高度维数和宽度维数分别编码空域信息,以保留位置信息。Permute-MLP层最大的特点在于其中包含了为宽度方向、长度方向和通道方向独立建模的三个分支。

image.png

  • Permutator-MLP Block结构:

image.png

伪代码表示:

# H:height,W:width,C:channel,S:numberofsegments
# x:inputtensorofshape(H,W,C)
###################initialization####################################################
proj_h = nn.Linear(C, C)#Encodingspatialinformationalongtheheightdimension
proj_w = nn.Linear(C, C)#Encodingspatialinformationalongthewidthdimension
proj_c = nn.Linear(C, C)#Encodingchannelinformation
proj = nn.Linear(C, C)#Forinformationfusion
####################codeinforward##################################################
def permute_mlp(x):
  N = C // S
  x_h = x.reshape(H, W, N, S).permute(2, 1, 0, 3).reshape(N, W, H*S)
  x_h = self.proj_h(x_h).reshape(N, W, H, S).permute(2, 1, 0, 3).reshape(H, W, C)
  x_w = x.reshape(H, W, N, S).permute(0, 2, 1, 3).reshape(H, N, W*S)
  x_w = self.proj_w(x_w).reshape(H, N, W, S).permute(0, 2, 1, 3).reshape(H, W, C)
  x_c = self.proj_c(x)
  x = x_h + x_w + x_c
  x = self.proj(x)
  return x


如果要对 H 方向进行映射,那么首先就是进行特征矩阵的转置,即 (H,W,C) --> (C,W,H)。但是实际上作者分成了 S 段来实现。对于这段代码的理解,可以如下图所示:

image.png

最后将3个特征矩阵进行融合有多种简单的方式,其中的一种就是简单的相加,由于处理后的3个特征矩阵的shape都是相同的,都是HxWxC,所以可以做一个简单的相加处理,然后原文采取的方式是Split Attention(作者称为Weighted Permute-MLP),SE block 是对特征图每个通道算权重,Split Attention 是对于多个三维特征图算每个三维特征图的权重。


Split Attention过程描述如下,假设有k个H×W×C的特征矩阵 :

  • 首先是对于各个特征矩阵元素求和得到一个H×W×C的特征矩阵
  • 然后按照空间内进行全局平均池化 (等价于直接求和),将特征矩阵变为一个 C 维向量
  • 然后经过两个全连接层 C–>C’–> kC,实现中作者将 C‘ 确定为 C//4,Dropout,BN,ReLU 等就不做叙述,并将全连接层的输出拆分为 k 个 C 维向量
  • 对于这k个C维的向量中的每一个做softmax处理(也就是每个向量权重的总和为1)
  • 将每个 Split Attention 模块输入的每个 Feature Map 和计算出来的每个 split 的权重相乘,再将结果加和,得到最终的结果


详细内容:Vision Permutator 网络详解


5. S2-MLPv2


出发点:S2-MLP的感受野比较局限,而且没有保存中间位移点

重新设计了特征融合与空间位移的结构,将特征矩阵分成3个部分。其中两个部分采用相反的空间移动测量,最后一部分保持不变,使其可以保持中间的位移点,完善MLP中的感受野概念。

image.png

可以看见,最后3个相同shape的特征矩阵进行融合时,采用的还是Split Attention的方法(与Vip采用的方法相同)


  • S2-MLPv2 Block结构:

image.png

伪代码表示:

def spatial_shift1(x):
  b,w,h,c = x.size()
  x[:,1:,:,:c/4] = x[:,:w-1,:,:c/4]
  x[:,:w-1,:,c/4:c/2] = x[:,1:,:,c/4:c/2]
  x[:,:,1:,c/2:c*3/4] = x[:,:,:h-1,c/2:c*3/4]
  x[:,:,:h-1,3*c/4:] = x[:,:,1:,3*c/4:]
  return x
def spatial_shift2(x):
  b,w,h,c = x.size()
  x[:,:,1:,:c/4] = x[:,:,:h-1,:c/4]
  x[:,:,:h-1,c/4:c/2] = x[:,:,1:,c/4:c/2]
  x[:,1:,:,c/2:c*3/4] = x[:,:w-1,:,c/2:c*3/4]
  x[:,:w-1,:,3*c/4:] = x[:,1:,:,3*c/4:]
  return x
class S2-MLPv2(nn.Module):
  def __init__(self, channels):
  super().__init__()
  self.mlp1 = nn.Linear(channels,channels*3)
  self.mlp2 = nn.Linear(channels,channels)
  self.split_attention = SplitAttention()
  def forward(self, x):
  b,w,h,c = x.size()
  x = self.mlp1(x)
  x1 = spatial_shift1(x[:,:,:,:c/3])
  x2 = spatial_shift2(x[:,:,:,c/3:c/3*2])
  x3 = x[:,:,:,c/3*2:]
  a = self.split_attention(x1,x2,x3)
  x = self.mlp2(a)
  return x


详细内容:S2MLPv2 网络详解

目录
相关文章
|
2月前
|
机器学习/深度学习 存储 人工智能
白话文讲解大模型| Attention is all you need
本文档旨在详细阐述当前主流的大模型技术架构如Transformer架构。我们将从技术概述、架构介绍到具体模型实现等多个角度进行讲解。通过本文档,我们期望为读者提供一个全面的理解,帮助大家掌握大模型的工作原理,增强与客户沟通的技术基础。本文档适合对大模型感兴趣的人员阅读。
519 18
白话文讲解大模型| Attention is all you need
|
6月前
|
机器学习/深度学习
神经网络可能不再需要激活函数?Layer Normalization也具有非线性表达!
【7月更文挑战第14天】研究表明,层归一化(LayerNorm)可能具备非线性表达能力,挑战了神经网络对激活函数的依赖。在LN-Net结构中,仅使用线性层与LayerNorm就能实现复杂分类,其VC维度下界证明了非线性表达。尽管如此,是否能完全替代激活函数及如何有效利用这一特性仍需更多研究。[arXiv:2406.01255]
69 5
|
机器学习/深度学习 数据挖掘 PyTorch
图像分类经典神经网络大总结(AlexNet、VGG 、GoogLeNet 、ResNet、 DenseNet、SENet、ResNeXt )
图像分类经典神经网络大总结(AlexNet、VGG 、GoogLeNet 、ResNet、 DenseNet、SENet、ResNeXt )
7511 1
图像分类经典神经网络大总结(AlexNet、VGG 、GoogLeNet 、ResNet、 DenseNet、SENet、ResNeXt )
|
机器学习/深度学习 人工智能 自然语言处理
【Deep Learning 8】Self-Attention自注意力神经网络
🍊本文主要介绍了Self-Attention产生的背景以及解析了具体的网络模型。
132 0
|
机器学习/深度学习 编解码 算法
全新池化方法AdaPool | 让ResNet、DenseNet、ResNeXt等在所有下游任务轻松涨点(一)
全新池化方法AdaPool | 让ResNet、DenseNet、ResNeXt等在所有下游任务轻松涨点(一)
267 0
|
编解码 数据可视化 计算机视觉
全新池化方法AdaPool | 让ResNet、DenseNet、ResNeXt等在所有下游任务轻松涨点(二)
全新池化方法AdaPool | 让ResNet、DenseNet、ResNeXt等在所有下游任务轻松涨点(二)
293 0
|
机器学习/深度学习 自然语言处理 PyTorch
Attention-lvcsr、Residual LSTM…你都掌握了吗?一文总结语音识别必备经典模型(1)
Attention-lvcsr、Residual LSTM…你都掌握了吗?一文总结语音识别必备经典模型
230 0
|
机器学习/深度学习 人工智能 搜索推荐
Attention-lvcsr、Residual LSTM…你都掌握了吗?一文总结语音识别必备经典模型(3)
Attention-lvcsr、Residual LSTM…你都掌握了吗?一文总结语音识别必备经典模型
194 0
|
机器学习/深度学习 算法 语音技术
Attention-lvcsr、Residual LSTM…你都掌握了吗?一文总结语音识别必备经典模型(2)
Attention-lvcsr、Residual LSTM…你都掌握了吗?一文总结语音识别必备经典模型
237 0
|
机器学习/深度学习 异构计算
AlexNet相比传统的CNN(比如LeNet)有哪些重要改动呢:Data Augmentation,Dropout,(3) ReLU激活函数
AlexNet相比传统的CNN(比如LeNet)有哪些重要改动呢:Data Augmentation,Dropout,(3) ReLU激活函数
243 0
AlexNet相比传统的CNN(比如LeNet)有哪些重要改动呢:Data Augmentation,Dropout,(3) ReLU激活函数