# YOLOv3 backbone: a reproduction of the Darknet53 network structure, written in a very basic style.
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional


class ConvolutionalLayers(nn.Module):
    """Basic building block: Conv2d -> BatchNorm2d -> LeakyReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.Conv2d``.
        bias: Whether the convolution has a bias term. Defaults to ``False``
            because the convolution is immediately followed by batch norm.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride,
                 padding, bias: bool = False):
        super().__init__()
        # NOTE(review): nn.LeakyReLU() uses PyTorch's default negative_slope
        # (0.01), whereas the reference Darknet implementation uses 0.1. Left
        # unchanged so that existing checkpoints keep producing the same output.
        self.sub_module = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                      bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalisation and the activation to ``x``."""
        return self.sub_module(x)


class Residual(nn.Module):
    """Residual unit: 1x1 bottleneck to ``in_channels // 2`` followed by a 3x3
    convolution back to ``in_channels``, added to the input.

    Args:
        in_channels: Channel count of both the input and the output.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        bottleneck = in_channels // 2
        self.sub_module = nn.Sequential(
            ConvolutionalLayers(in_channels, bottleneck, 1, 1, 0),
            ConvolutionalLayers(bottleneck, in_channels, 3, 1, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + F(x)``; the output has the same shape as ``x``."""
        return x + self.sub_module(x)


class Convolutional_Set(nn.Module):
    """Five alternating 1x1 / 3x3 convolution blocks.

    The channel count goes in -> out -> in -> out -> in -> out, so the result
    has ``out_channels`` channels and the same spatial size as the input.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output feature map.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.sub_module = nn.Sequential(
            ConvolutionalLayers(in_channels, out_channels, 1, 1, 0),
            ConvolutionalLayers(out_channels, in_channels, 3, 1, 1),
            ConvolutionalLayers(in_channels, out_channels, 1, 1, 0),
            ConvolutionalLayers(out_channels, in_channels, 3, 1, 1),
            ConvolutionalLayers(in_channels, out_channels, 1, 1, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the five convolution blocks."""
        return self.sub_module(x)


class UpSamplingLayers(nn.Module):
    """Parameter-free 2x nearest-neighbour up-sampling."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with its spatial dimensions doubled."""
        return functional.interpolate(x, scale_factor=2, mode='nearest')


def _downsample(in_channels: int, out_channels: int) -> ConvolutionalLayers:
    """Build the 3x3, stride-2 convolution block that halves the spatial size."""
    return ConvolutionalLayers(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=3, stride=2, padding=1)


def _residual_stack(channels: int, count: int) -> list:
    """Build ``count`` independent ``Residual(channels)`` units, in order."""
    return [Residual(channels) for _ in range(count)]


def _detection_head(in_channels: int, out_channels: int) -> nn.Sequential:
    """Build a prediction head: 3x3 block doubling the channels, then a plain
    1x1 convolution producing ``out_channels`` raw prediction channels."""
    hidden = in_channels * 2
    return nn.Sequential(
        ConvolutionalLayers(in_channels=in_channels, out_channels=hidden,
                            kernel_size=3, stride=1, padding=1),
        nn.Conv2d(hidden, out_channels, 1, 1, 0),
    )


class Darknet53(nn.Module):
    """Darknet53 backbone with the three YOLOv3-style detection branches.

    The network returns three prediction maps at strides 32, 16 and 8 (13x13,
    26x26 and 52x52 for a 416x416 input). The input height and width should be
    multiples of 32, otherwise the up-sampled maps do not line up with the
    skip connections they are concatenated with.

    Args:
        num_classes: Number of object classes predicted per anchor.
        num_anchors: Number of anchors predicted per grid cell.

    Each head outputs ``num_anchors * (5 + num_classes)`` channels. The
    defaults (3 and 3) reproduce the 24 channels that used to be hard-coded.

    Raises:
        ValueError: If ``num_classes`` or ``num_anchors`` is smaller than 1.
    """

    def __init__(self, num_classes: int = 3, num_anchors: int = 3):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if num_anchors < 1:
            raise ValueError(f"num_anchors must be >= 1, got {num_anchors}")
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        head_channels = num_anchors * (5 + num_classes)

        # Attribute names (including the historical "detetion" spelling), the
        # Sequential nesting and the construction order are kept exactly as
        # they were so that state_dict keys and seeded initialisation of
        # existing checkpoints/experiments stay valid.

        # Backbone up to stride 8 (256 channels).
        self.Residual_Block_52 = nn.Sequential(
            ConvolutionalLayers(in_channels=3, out_channels=32, kernel_size=3,
                                stride=1, padding=1),
            _downsample(32, 64),
            *_residual_stack(64, 1),
            _downsample(64, 128),
            *_residual_stack(128, 2),
            _downsample(128, 256),
            *_residual_stack(256, 8),
        )
        # Backbone from stride 8 to stride 16 (512 channels).
        self.Residual_Block_26 = nn.Sequential(
            _downsample(256, 512),
            *_residual_stack(512, 8),
        )
        # Backbone from stride 16 to stride 32 (1024 channels).
        self.Residual_Block_13 = nn.Sequential(
            _downsample(512, 1024),
            *_residual_stack(1024, 4),
        )

        # ---- Branch 1: stride 32 ------------------------------------------
        self.convset_13 = nn.Sequential(Convolutional_Set(1024, 512))
        self.detetion_13 = _detection_head(512, head_channels)
        self.up_13to26 = nn.Sequential(
            ConvolutionalLayers(512, 256, 3, 1, 1),
            UpSamplingLayers(),
        )

        # ---- Branch 2: stride 16 ------------------------------------------
        # 768 = 256 (up-sampled branch 1) + 512 (Residual_Block_26 output).
        self.convset_26 = nn.Sequential(Convolutional_Set(768, 256))
        self.detetion_26 = _detection_head(256, head_channels)
        self.up_26to52 = nn.Sequential(
            ConvolutionalLayers(256, 128, 3, 1, 1),
            UpSamplingLayers(),
        )

        # ---- Branch 3: stride 8 -------------------------------------------
        # 384 = 128 (up-sampled branch 2) + 256 (Residual_Block_52 output).
        self.convset_52 = nn.Sequential(Convolutional_Set(384, 128))
        self.detetion_52 = _detection_head(128, head_channels)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the three prediction maps for a batch of images.

        Args:
            x: Input tensor with 3 channels in dimension 1.

        Returns:
            A tuple ``(out_13, out_26, out_52)`` holding the predictions of the
            stride-32, stride-16 and stride-8 branches, in that order.
        """
        # Backbone features at the three scales.
        feat_52 = self.Residual_Block_52(x)
        feat_26 = self.Residual_Block_26(feat_52)
        feat_13 = self.Residual_Block_13(feat_26)

        # Coarsest branch.
        convset_out_13 = self.convset_13(feat_13)
        detetion_out_13 = self.detetion_13(convset_out_13)

        # Middle branch: up-sample and concatenate along the channel axis.
        route_out_26 = torch.cat((self.up_13to26(convset_out_13), feat_26), dim=1)
        convset_out_26 = self.convset_26(route_out_26)
        detetion_out_26 = self.detetion_26(convset_out_26)

        # Finest branch: same routing one scale up.
        route_out_52 = torch.cat((self.up_26to52(convset_out_26), feat_52), dim=1)
        convset_out_52 = self.convset_52(route_out_52)
        detetion_out_52 = self.detetion_52(convset_out_52)

        return detetion_out_13, detetion_out_26, detetion_out_52


if __name__ == '__main__':
    # Smoke test: print the shape of each prediction map for a 416x416 image.
    yolo = Darknet53()
    x = torch.randn(1, 3, 416, 416)
    y = yolo(x)
    print(y[0].shape)
    print(y[1].shape)
    print(y[2].shape)