前言
在卷积操作中,每个卷积核都是固定大小的,无法适应输入数据中的不同空间结构。为了解决这个问题,一种新型的卷积操作——SKCONV应运而生。SKCONV引入了动态卷积核形状的概念,能够自适应地学习卷积核的形状和大小,从而能够更好地适应输入数据的空间结构,提高模型的性能。本文将介绍SKCONV的原理、实现方式以及在不同任务中的应用效果,希望能够帮助读者更好地了解这一创新性的卷积操作。
原理
SKCONV是一种动态卷积操作,它能够根据输入数据的不同空间结构自适应地学习卷积核的形状和大小。SKCONV的原理是基于Split-and-Combine(拆分与合并)思想和Channel-wise交互(通道间交互)机制。SKCONV的拆分与合并思想指的是将输入特征图按通道数进行拆分,然后对每个通道进行卷积操作,最后再将不同通道的结果合并。SKCONV通过不同通道间的交互,自适应地学习卷积核的形状和大小,从而适应输入数据中的不同空间结构。
交互机制
SKCONV中的Channel-wise交互机制主要包括两个部分:组稀疏交互和通道注意力机制。组稀疏交互指的是将输入特征图按组进行拆分,并只允许不同组之间的交互,这样可以减少模型的计算量。通道注意力机制则是对不同通道的特征进行加权,让网络自动学习每个通道的重要性,从而提高特征图的表达能力。
精要概括
SKCONV的精髓部分是动态卷积核形状、Split-and-Combine思想、组稀疏交互和通道注意力机制。SKCONV的精髓部分主要包括以下几个方面:
- 动态卷积核形状:SKCONV能够根据输入数据的不同空间结构自适应地学习卷积核的形状和大小,从而能够更好地适应输入数据的空间结构,提高模型的性能。
- Split-and-Combine思想:SKCONV采用Split-and-Combine思想,将输入特征图按通道数进行拆分,然后对每个通道进行卷积操作,最后再将不同通道的结果合并。这种拆分与合并的方法能够在减少计算量的同时,提高特征图的表达能力。
- 组稀疏交互:SKCONV采用组稀疏交互机制,只允许不同组之间的交互,从而减少了模型的计算量。这种交互方式还能够帮助模型更好地捕捉不同通道之间的相关性。
- 通道注意力机制:SKCONV采用通道注意力机制,自动学习每个通道的重要性,从而提高特征图的表达能力。这种机制能够让模型更加关注重要的特征,抑制无用的特征,从而提高模型的性能。
综上所述,SKCONV能够自适应地学习卷积核的形状和大小,从而能够更好地适应输入数据的空间结构,提高模型的性能。同时,SKCONV还采用了组稀疏交互和通道注意力机制,进一步提高了模型的表达能力。
实现
SKCONV的实现方式主要包括Split、Group Convolution、Combine和Group Shuffle等步骤,其中Group Convolution是SKCONV的核心实现方式,它通过自适应学习卷积核的形状和大小,并采用组稀疏交互机制,从而提高特征图的表达能力,减少计算量,提高模型性能。
SKCONV的实现方式可以分为以下几个步骤:
- Split操作:首先将输入的特征图按通道数进行拆分,得到多个子特征图,每个子特征图只包含一个通道。
- Group Convolution操作:对每个子特征图进行卷积操作,使用多个不同形状和大小的卷积核进行卷积,得到多个卷积结果。
- Combine操作:将不同通道的卷积结果进行组合,得到最终的输出特征图。在组合过程中,可以采用通道注意力机制,自适应地学习每个通道的权重,从而提高特征图的表达能力。
- Group Shuffle操作:将组内的通道进行随机重排,从而增加模型的随机性和泛化能力。
python
复制代码
import torch from torch import nn class SKConv(nn.Module): def __init__(self, features, WH, M, G, r, stride=1, L=32): """ Constructor Args: features: input channel dimensionality. WH: input spatial dimensionality, used for GAP kernel size. M: the number of branchs. G: num of convolution groups. r: the radio for compute d, the length of z. stride: stride, default 1. L: the minimum dim of the vector z in paper, default 32. """ super(SKConv, self).__init__() d = max(int(features / r), L) self.M = M self.features = features self.convs = nn.ModuleList([]) for i in range(M): self.convs.append(nn.Sequential( nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G), nn.BatchNorm2d(features), nn.ReLU(inplace=False) )) # self.gap = nn.AvgPool2d(int(WH/stride)) self.fc = nn.Linear(features, d) self.fcs = nn.ModuleList([]) for i in range(M): self.fcs.append( nn.Linear(d, features) ) self.softmax = nn.Softmax(dim=1) def forward(self, x): for i, conv in enumerate(self.convs): fea = conv(x).unsqueeze_(dim=1) if i == 0: feas = fea else: feas = torch.cat([feas, fea], dim=1) fea_U = torch.sum(feas, dim=1) # fea_s = self.gap(fea_U).squeeze_() fea_s = fea_U.mean(-1).mean(-1) fea_z = self.fc(fea_s) for i, fc in enumerate(self.fcs): vector = fc(fea_z).unsqueeze_(dim=1) if i == 0: attention_vectors = vector else: attention_vectors = torch.cat([attention_vectors, vector], dim=1) attention_vectors = self.softmax(attention_vectors) attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1) fea_v = (feas * attention_vectors).sum(dim=1) return fea_v class SKUnit(nn.Module): def __init__(self, in_features, out_features, WH, M, G, r, mid_features=None, stride=1, L=32): """ Constructor Args: in_features: input channel dimensionality. out_features: output channel dimensionality. WH: input spatial dimensionality, used for GAP kernel size. M: the number of branchs. G: num of convolution groups. r: the radio for compute d, the length of z. mid_features: the channle dim of the middle conv with stride not 1, default out_features/2. stride: stride. L: the minimum dim of the vector z in paper. """ super(SKUnit, self).__init__() if mid_features is None: mid_features = int(out_features / 2) self.feas = nn.Sequential( nn.Conv2d(in_features, mid_features, 1, stride=1), nn.BatchNorm2d(mid_features), SKConv(mid_features, WH, M, G, r, stride=stride, L=L), nn.BatchNorm2d(mid_features), nn.Conv2d(mid_features, out_features, 1, stride=1), nn.BatchNorm2d(out_features) ) if in_features == out_features: # when dim not change, in could be added diectly to out self.shortcut = nn.Sequential() else: # when dim not change, in should also change dim to be added to out self.shortcut = nn.Sequential( nn.Conv2d(in_features, out_features, 1, stride=stride), nn.BatchNorm2d(out_features) ) def forward(self, x): fea = self.feas(x) return fea + self.shortcut(x)
实验结果
下表是论文中提到的实验结果对比图,其中224×表示用于评估的单个224×224作物,同样是320×。需要注意的是,senet / sknet都是基于相应的ResNeXt主干
总结
针对上述的SKConv结构,大家可以自行嵌入到分类网络中以及检测任务、分割任务等。希望本篇文章对大家有所帮助!