A summary of several useful CNN building blocks, with PyTorch code for each.
- SEBlock

Code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SEBlock(nn.Module):
    """Squeeze-and-Excitation: recalibrate channels via global pooling + 1x1 bottleneck."""

    def __init__(self, input_channels, internal_neurons):
        super(SEBlock, self).__init__()
        # 1x1 convs play the role of the squeeze/excite fully connected layers.
        # Note: padding_mode='same' is not valid for nn.Conv2d; a 1x1 conv needs no padding.
        self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons,
                              kernel_size=1, stride=1, bias=True)
        self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels,
                            kernel_size=1, stride=1, bias=True)

    def forward(self, inputs):
        # Global average pooling over both spatial dims -> (N, C, 1, 1),
        # so non-square inputs are handled correctly.
        x = F.avg_pool2d(inputs, kernel_size=(inputs.size(2), inputs.size(3)))
        x = self.down(x)
        x = F.leaky_relu(x)
        x = self.up(x)
        x = torch.sigmoid(x)  # F.sigmoid is deprecated
        # Broadcasting handles the (1, 1) spatial dims; no explicit repeat needed.
        return inputs * x
```
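A minimal smoke test (the channel counts and input shape below are illustrative, not from the original). SEBlock returns a tensor of the same shape as its input, so it can be dropped after any conv layer:

```python
se = SEBlock(input_channels=64, internal_neurons=16)
x = torch.randn(2, 64, 32, 32)
out = se(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```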
- ACBlock
Code:
```python
import torch.nn as nn
import torch.nn.init as init


class CropLayer(nn.Module):
    # E.g., (-1, 0) means this layer should crop the first and last rows of the
    # feature map, and (0, -1) crops the first and last columns.
    def __init__(self, crop_set):
        super(CropLayer, self).__init__()
        self.rows_to_crop = -crop_set[0]
        self.cols_to_crop = -crop_set[1]
        assert self.rows_to_crop >= 0
        assert self.cols_to_crop >= 0

    def forward(self, input):
        if self.rows_to_crop == 0 and self.cols_to_crop == 0:
            return input
        elif self.rows_to_crop > 0 and self.cols_to_crop == 0:
            return input[:, :, self.rows_to_crop:-self.rows_to_crop, :]
        elif self.rows_to_crop == 0 and self.cols_to_crop > 0:
            return input[:, :, :, self.cols_to_crop:-self.cols_to_crop]
        else:
            return input[:, :, self.rows_to_crop:-self.rows_to_crop,
                         self.cols_to_crop:-self.cols_to_crop]


class ACBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1,
                 dilation=1, groups=1, padding_mode='zeros', deploy=False,
                 use_affine=True, reduce_gamma=False, use_last_bn=False,
                 gamma_init=None):
        # padding_mode defaults to 'zeros'; 'same' is not a valid padding_mode
        # for nn.Conv2d.
        super(ACBlock, self).__init__()
        self.deploy = deploy
        if deploy:
            # Inference mode: the three branches have been fused into one k x k conv.
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=(kernel_size, kernel_size), stride=stride,
                                        padding=padding, dilation=dilation, groups=groups,
                                        bias=True, padding_mode=padding_mode)
        else:
            # Training mode: parallel square (k x k), vertical (k x 1) and
            # horizontal (1 x k) branches, each followed by its own BN.
            self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups,
                                         bias=False, padding_mode=padding_mode)
            self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            center_offset_from_origin_border = padding - kernel_size // 2
            ver_pad_or_crop = (padding, center_offset_from_origin_border)
            hor_pad_or_crop = (center_offset_from_origin_border, padding)
            if center_offset_from_origin_border >= 0:
                self.ver_conv_crop_layer = nn.Identity()
                ver_conv_padding = ver_pad_or_crop
                self.hor_conv_crop_layer = nn.Identity()
                hor_conv_padding = hor_pad_or_crop
            else:
                self.ver_conv_crop_layer = CropLayer(crop_set=ver_pad_or_crop)
                ver_conv_padding = (0, 0)
                self.hor_conv_crop_layer = CropLayer(crop_set=hor_pad_or_crop)
                hor_conv_padding = (0, 0)
            self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=(kernel_size, 1), stride=stride,
                                      padding=ver_conv_padding, dilation=dilation,
                                      groups=groups, bias=False, padding_mode=padding_mode)
            self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      kernel_size=(1, kernel_size), stride=stride,
                                      padding=hor_conv_padding, dilation=dilation,
                                      groups=groups, bias=False, padding_mode=padding_mode)
            self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
            self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            if reduce_gamma:
                assert not use_last_bn
                self.init_gamma(1.0 / 3)
            if use_last_bn:
                assert not reduce_gamma
                self.last_bn = nn.BatchNorm2d(num_features=out_channels, affine=True)
            if gamma_init is not None:
                assert not reduce_gamma
                self.init_gamma(gamma_init)

    def init_gamma(self, gamma_value):
        init.constant_(self.square_bn.weight, gamma_value)
        init.constant_(self.ver_bn.weight, gamma_value)
        init.constant_(self.hor_bn.weight, gamma_value)
        print('init gamma of square, ver and hor as ', gamma_value)

    def single_init(self):
        init.constant_(self.square_bn.weight, 1.0)
        init.constant_(self.ver_bn.weight, 0.0)
        init.constant_(self.hor_bn.weight, 0.0)
        print('init gamma of square as 1, ver and hor as 0')

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        square_outputs = self.square_conv(input)
        square_outputs = self.square_bn(square_outputs)
        vertical_outputs = self.ver_conv_crop_layer(input)
        vertical_outputs = self.ver_conv(vertical_outputs)
        vertical_outputs = self.ver_bn(vertical_outputs)
        horizontal_outputs = self.hor_conv_crop_layer(input)
        horizontal_outputs = self.hor_conv(horizontal_outputs)
        horizontal_outputs = self.hor_bn(horizontal_outputs)
        result = square_outputs + vertical_outputs + horizontal_outputs
        if hasattr(self, 'last_bn'):
            return self.last_bn(result)
        return result
```
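An illustrative instantiation (shapes and channel counts are my own, not from the original). In the ACNet workflow the three training-time branches are fused into a single conv offline; this snippet only shows that both modes produce the same output shape:

```python
import torch

block = ACBlock(in_channels=32, out_channels=64, kernel_size=3, padding=1)
x = torch.randn(2, 32, 56, 56)
y = block(x)           # training mode: square + vertical + horizontal branches
print(y.shape)         # torch.Size([2, 64, 56, 56])

deploy_block = ACBlock(in_channels=32, out_channels=64, kernel_size=3,
                       padding=1, deploy=True)  # single fused 3x3 conv
print(deploy_block(x).shape)  # torch.Size([2, 64, 56, 56])
```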
- ChannelAttention
Code:

```python
import torch.nn as nn


class ChannelAttention(nn.Module):
    """CBAM-style channel attention: a shared MLP over avg- and max-pooled descriptors."""

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Use the ratio argument rather than a hard-coded 16.
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.LeakyReLU(negative_slope=0.01, inplace=False)
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
```
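A minimal usage sketch (shapes are illustrative). The module returns per-channel weights in (0, 1), which are multiplied back onto the feature map as in CBAM:

```python
import torch

ca = ChannelAttention(in_planes=64, ratio=16)
x = torch.randn(2, 64, 32, 32)
w = ca(x)       # (2, 64, 1, 1) channel weights
x = x * w       # broadcast multiply over the spatial dims
```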
- ConvBN

Code:
```python
from collections import OrderedDict

import torch.nn as nn


class ConvBN(nn.Sequential):
    """Conv2d + BatchNorm + activation with 'same'-style padding for odd kernel sizes."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        if not isinstance(kernel_size, int):
            padding = [(i - 1) // 2 for i in kernel_size]
        else:
            padding = (kernel_size - 1) // 2
        super(ConvBN, self).__init__(OrderedDict([
            ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                               padding=padding, groups=groups, bias=False)),
            ('bn', nn.BatchNorm2d(out_planes)),
            # ('Mish', Mish())  # original activation, swapped for LeakyReLU below
            ('Mish', nn.LeakyReLU(negative_slope=0.3, inplace=False)),
        ]))
```
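A quick usage example (shapes are illustrative). With an odd kernel size, the computed padding keeps the output spatial size at input_size / stride:

```python
import torch

layer = ConvBN(in_planes=3, out_planes=32, kernel_size=3, stride=2)
x = torch.randn(2, 3, 224, 224)
print(layer(x).shape)  # torch.Size([2, 32, 112, 112])
```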