1. Experimental Principles and Objectives
The experiment uses the U-Net network, a pixel-wise segmentation architecture, to detect targets such as ships, vehicles, faces, and roads. The U-Net structure is described below.
U-Net is an encoder-decoder architecture. The left half, the encoder, applies successive convolutions and pooling to downsample the image; the right half, the decoder, upsamples back to the original image size and produces a prediction for every pixel.
The encoder consists of four sub-modules, each containing two convolutional layers and followed by a max-pooling downsampling layer.
With a 572x572 input, the feature maps of modules 1-5 have resolutions of 572x572, 284x284, 140x140, 68x68, and 32x32, respectively: each module's two valid (unpadded) 3x3 convolutions trim 4 pixels in each dimension and the 2x2 max-pool halves the result (e.g. 572 → 568, pooled to 284).
The decoder also contains four sub-modules, and the resolution is raised step by step through upsampling until it matches the input. The network uses skip connections that concatenate each upsampled feature map with the output of the encoder sub-module at the same resolution, and feeds the result to the next decoder sub-module.
An important modification in the architecture is that the upsampling path also carries a large number of feature channels, which allows the network to propagate context information to the higher-resolution layers. As a consequence, the expansive path is more or less symmetric to the contracting path, yielding a U-shaped structure.
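To make the structure concrete, here is a minimal two-level U-Net sketch in PyTorch (an illustrative toy, not the network trained below; all names are hypothetical):

import torch
import torch.nn as nn
import torch.nn.functional as F

def double_conv(in_ch, out_ch):
    # Two 3x3 convolutions, as in each U-Net sub-module (padding=1 keeps the size)
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True))

class TinyUNet(nn.Module):
    # One encoder stage, one decoder stage, one skip connection
    def __init__(self):
        super().__init__()
        self.enc1 = double_conv(3, 16)
        self.enc2 = double_conv(16, 32)
        self.up = nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.dec1 = double_conv(32, 16)   # 32 = 16 (skip) + 16 (upsampled)
        self.head = nn.Conv2d(16, 1, 1)   # per-pixel prediction

    def forward(self, x):
        e1 = self.enc1(x)                                # full resolution
        e2 = self.enc2(F.max_pool2d(e1, 2))              # downsampled by max-pool
        d1 = self.dec1(torch.cat([self.up(e2), e1], 1))  # skip connection
        return self.head(d1)

print(TinyUNet()(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 1, 64, 64])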
The network has no fully connected layers and uses only the valid part of each convolution, i.e., the segmentation map contains only those pixels for which the full context is available in the input image. This strategy allows seamless segmentation of arbitrarily large images through an overlap-tile strategy. To predict the pixels in the border region of the image, the missing context is extrapolated by mirroring the input image. This tiling strategy is important for applying the network to large images, since otherwise the resolution would be limited by GPU memory.
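The mirroring can be reproduced in a few lines; a sketch of reflect-padding one tile before feeding it to the network (tile size and margin are illustrative):

import torch
import torch.nn.functional as F

tile = torch.randn(1, 3, 256, 256)                          # one tile cropped from a large image
pad = 32                                                    # border context the network needs
padded = F.pad(tile, (pad, pad, pad, pad), mode='reflect')  # mirror the missing context
print(padded.shape)                                         # torch.Size([1, 3, 320, 320])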
2. Experimental Content
This experiment applies the U-Net network to road detection. The test dataset is stored in the accompanying folder; training the network on the dataset yields the road detection results.
3. Experimental Program
3.1 Importing Libraries
# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models, utils
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
#from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import copy
import cv2
import argparse        # argparse: parse command-line arguments
from tqdm import tqdm  # progress bar
3.2 Creating the Parser Object
# Create an argument parser object
parser = argparse.ArgumentParser(description="Choose mode")
3.3 Command-Line Arguments
# Command-line arguments
parser.add_argument('-mode', required=True, choices=['train', 'test'], default='train')
parser.add_argument('-dim', type=int, default=16)
parser.add_argument('-num_epochs', type=int, default=3)
parser.add_argument('-image_scale_h', type=int, default=256)
parser.add_argument('-image_scale_w', type=int, default=256)
parser.add_argument('-batch', type=int, default=4)
parser.add_argument('-img_cut', type=int, default=4)
parser.add_argument('-lr', type=float, default=5e-5)
parser.add_argument('-lr_1', type=float, default=5e-5)
parser.add_argument('-alpha', type=float, default=0.05)
parser.add_argument('-sa_scale', type=int, default=8)  # was type=float; an integer is needed for the channel division in Self_Attn
parser.add_argument('-latent_size', type=int, default=100)
parser.add_argument('-data_path', type=str, default='./munich/train/img')
parser.add_argument('-label_path', type=str, default='./munich/train/lab')
parser.add_argument('-gpu', type=str, default='0')
parser.add_argument('-load_model', required=True, choices=['True', 'False'], help='choose True or False', default='False')
3.4 Parsing with parse_args()
# Parse the arguments with parse_args()
opt = parser.parse_args()
print(opt)
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
use_cuda = torch.cuda.is_available()
print("use_cuda:", use_cuda)
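With the parser in place, the script is driven from the command line; a hypothetical invocation (the filename unet_road.py is an assumption, and -mode and -load_model must be supplied because both are declared required):

python unet_road.py -mode train -load_model False -num_epochs 3 -batch 4 -gpu 0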
3.5 Selecting the Device (GPU if Available)
# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if use_cuda else "cpu")
IMG_CUT = opt.img_cut
LATENT_SIZE = opt.latent_size
writer = SummaryWriter('./runs2/gx0102')
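The scalars written through writer during training can be viewed with TensorBoard, which reads the event files from the log directory:

tensorboard --logdir ./runs2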
3.6 Creating File Paths
# Create a directory if it does not already exist
def auto_create_path(FilePath):
    if os.path.exists(FilePath):
        print(FilePath + ' dir exists')
    else:
        print(FilePath + ' dir not exists')
        os.makedirs(FilePath)
3.7 Creating Folders for the Training Results
# Create folders to store the training results
auto_create_path('./test/lab_dete_AVD')
auto_create_path('./model')
auto_create_path('./results')
3.8 Downsampling: Residual Blocks and Self-Attention
# Residual blocks with optional down/upsampling, plus a self-attention layer
class ResidualBlockClass(nn.Module):
    def __init__(self, name, input_dim, output_dim, resample=None, activate='relu'):
        super(ResidualBlockClass, self).__init__()
        self.name = name
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.batchnormlize_1 = nn.BatchNorm2d(input_dim)
        self.activate = activate
        if resample == 'down':
            self.conv_0 = nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1)
            self.conv_shortcut = nn.AvgPool2d(3, stride=2, padding=1)
            self.conv_1 = nn.Conv2d(in_channels=input_dim, out_channels=input_dim, kernel_size=3, stride=1, padding=1)
            self.conv_2 = nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=2, padding=1)
            self.batchnormlize_2 = nn.BatchNorm2d(input_dim)
        elif resample == 'up':
            self.conv_0 = nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1)
            self.conv_shortcut = nn.Upsample(scale_factor=2)
            self.conv_1 = nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1)
            self.conv_2 = nn.ConvTranspose2d(in_channels=output_dim, out_channels=output_dim, kernel_size=3, stride=2,
                                             padding=2, output_padding=1, dilation=2)
            self.batchnormlize_2 = nn.BatchNorm2d(output_dim)
        elif resample is None:
            self.conv_shortcut = nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1)
            self.conv_1 = nn.Conv2d(in_channels=input_dim, out_channels=input_dim, kernel_size=3, stride=1, padding=1)
            self.conv_2 = nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1)
            self.batchnormlize_2 = nn.BatchNorm2d(input_dim)
        else:
            raise Exception('invalid resample value')

    def forward(self, inputs):
        if self.output_dim == self.input_dim and self.resample is None:
            shortcut = inputs
        elif self.resample == 'down':
            x = self.conv_0(inputs)
            shortcut = self.conv_shortcut(x)
        elif self.resample is None:
            x = inputs
            shortcut = self.conv_shortcut(x)
        else:
            x = self.conv_0(inputs)
            shortcut = self.conv_shortcut(x)
        if self.activate == 'relu':
            x = inputs
            x = self.batchnormlize_1(x)
            x = F.relu(x)
            x = self.conv_1(x)
            x = self.batchnormlize_2(x)
            x = F.relu(x)
            x = self.conv_2(x)
            return shortcut + x
        else:
            x = inputs
            x = self.batchnormlize_1(x)
            x = F.leaky_relu(x)
            x = self.conv_1(x)
            x = self.batchnormlize_2(x)
            x = F.leaky_relu(x)
            x = self.conv_2(x)
            return shortcut + x


class Self_Attn(nn.Module):
    """Self-attention layer"""
    def __init__(self, in_dim, activation=None):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        # self.activation = activation
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // opt.sa_scale, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // opt.sa_scale, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x: input feature maps (B X C X W X H)
        returns:
            out: self-attention value + input feature
            attention: B X N X N (N is Width*Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B X (W*H) X C
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # B X C X (W*H)
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        attention = self.softmax(energy)          # B X N X N
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # B X C X N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = self.gamma * out + x
        return out
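A quick shape check of the two blocks (a hypothetical snippet; it assumes the arguments above have been parsed so that opt.sa_scale exists): a 'down' residual block halves the spatial size, while self-attention preserves shape.

x = torch.randn(2, 32, 64, 64)
down = ResidualBlockClass('demo.down', 32, 64, resample='down', activate='leaky_relu')
y = down(x)
print(y.shape)             # torch.Size([2, 64, 32, 32]): channels doubled, resolution halved
sa = Self_Attn(in_dim=64)  # query/key use 64 // opt.sa_scale channels
print(sa(y).shape)         # torch.Size([2, 64, 32, 32]): attention preserves the shape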
3.9 Upsampling with Convolutions: Up-Projection Blocks and the Encoder
# Up-projection: upsampling that restores spatial resolution with convolutions
class UpProject(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpProject, self).__init__()
        # self.batch_size = batch_size
        self.conv1_1 = nn.Conv2d(in_channels, out_channels, 3)
        self.conv1_2 = nn.Conv2d(in_channels, out_channels, (2, 3))
        self.conv1_3 = nn.Conv2d(in_channels, out_channels, (3, 2))
        self.conv1_4 = nn.Conv2d(in_channels, out_channels, 2)
        self.conv2_1 = nn.Conv2d(in_channels, out_channels, 3)
        self.conv2_2 = nn.Conv2d(in_channels, out_channels, (2, 3))
        self.conv2_3 = nn.Conv2d(in_channels, out_channels, (3, 2))
        self.conv2_4 = nn.Conv2d(in_channels, out_channels, 2)
        self.bn1_1 = nn.BatchNorm2d(out_channels)
        self.bn1_2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        batch_size = x.shape[0]
        out1_1 = self.conv1_1(nn.functional.pad(x, (1, 1, 1, 1)))
        out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 0, 1)))  # right interleaving padding
        # out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 1, 0)))  # author's interleaving padding on GitHub
        out1_3 = self.conv1_3(nn.functional.pad(x, (0, 1, 1, 1)))  # right interleaving padding
        # out1_3 = self.conv1_3(nn.functional.pad(x, (1, 0, 1, 1)))  # author's interleaving padding on GitHub
        out1_4 = self.conv1_4(nn.functional.pad(x, (0, 1, 0, 1)))  # right interleaving padding
        # out1_4 = self.conv1_4(nn.functional.pad(x, (1, 0, 1, 0)))  # author's interleaving padding on GitHub
        out2_1 = self.conv2_1(nn.functional.pad(x, (1, 1, 1, 1)))
        out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 0, 1)))  # right interleaving padding
        # out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 1, 0)))  # author's interleaving padding on GitHub
        out2_3 = self.conv2_3(nn.functional.pad(x, (0, 1, 1, 1)))  # right interleaving padding
        # out2_3 = self.conv2_3(nn.functional.pad(x, (1, 0, 1, 1)))  # author's interleaving padding on GitHub
        out2_4 = self.conv2_4(nn.functional.pad(x, (0, 1, 0, 1)))  # right interleaving padding
        # out2_4 = self.conv2_4(nn.functional.pad(x, (1, 0, 1, 0)))  # author's interleaving padding on GitHub
        height = out1_1.size()[2]
        width = out1_1.size()[3]
        out1_1_2 = torch.stack((out1_1, out1_2), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(
            batch_size, -1, height, width * 2)
        out1_3_4 = torch.stack((out1_3, out1_4), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(
            batch_size, -1, height, width * 2)
        out1_1234 = torch.stack((out1_1_2, out1_3_4), dim=-3).permute(0, 1, 3, 2, 4).contiguous().view(
            batch_size, -1, height * 2, width * 2)
        out2_1_2 = torch.stack((out2_1, out2_2), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(
            batch_size, -1, height, width * 2)
        out2_3_4 = torch.stack((out2_3, out2_4), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view(
            batch_size, -1, height, width * 2)
        out2_1234 = torch.stack((out2_1_2, out2_3_4), dim=-3).permute(0, 1, 3, 2, 4).contiguous().view(
            batch_size, -1, height * 2, width * 2)
        out1 = self.bn1_1(out1_1234)
        out1 = self.relu(out1)
        out1 = self.conv3(out1)
        out1 = self.bn2(out1)
        out2 = self.bn1_2(out2_1234)
        out = out1 + out2
        out = self.relu(out)
        return out


# Encoder: downsampling
class Fcrn_encode(nn.Module):
    def __init__(self, dim=opt.dim):
        super(Fcrn_encode, self).__init__()
        self.dim = dim
        self.conv_1 = nn.Conv2d(in_channels=3, out_channels=dim, kernel_size=3, stride=1, padding=1)
        self.residual_block_1_down_1 = ResidualBlockClass('Detector.Res1', 1 * dim, 2 * dim,
                                                          resample='down', activate='leaky_relu')  # 128x128
        self.residual_block_2_down_1 = ResidualBlockClass('Detector.Res2', 2 * dim, 4 * dim,
                                                          resample='down', activate='leaky_relu')  # 64x64
        self.residual_block_3_down_1 = ResidualBlockClass('Detector.Res3', 4 * dim, 4 * dim,
                                                          resample='down', activate='leaky_relu')  # 32x32
        self.residual_block_4_down_1 = ResidualBlockClass('Detector.Res4', 4 * dim, 6 * dim,
                                                          resample='down', activate='leaky_relu')  # 16x16
        self.residual_block_5_none_1 = ResidualBlockClass('Detector.Res5', 6 * dim, 6 * dim,
                                                          resample=None, activate='leaky_relu')

    def forward(self, x, n1=0, n2=0, n3=0):
        x1 = self.conv_1(x)                                                       # x1: dim x 256 x 256
        x2 = self.residual_block_1_down_1(x1)                                     # x2: 2*dim x 128 x 128
        x3 = self.residual_block_2_down_1((1 - opt.alpha) * x2 + opt.alpha * n1)  # x3: 4*dim x 64 x 64
        x4 = self.residual_block_3_down_1((1 - opt.alpha) * x3 + opt.alpha * n2)  # x4: 4*dim x 32 x 32
        x = self.residual_block_4_down_1((1 - opt.alpha) * x4 + opt.alpha * n3)
        feature = self.residual_block_5_none_1(x)
        x = torch.tanh(feature)
        return x, x2, x3, x4
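A hedged shape walk-through of the encoder (CPU sketch; assumes the defaults opt.dim=16 and 256x256 inputs, with the noise inputs n1..n3 left at their 0 defaults):

enc = Fcrn_encode()
img = torch.randn(2, 3, 256, 256)
feat, x2, x3, x4 = enc(img)
print(feat.shape)                    # [2, 6*dim, 16, 16]: tanh-squashed bottleneck feature
print(x2.shape, x3.shape, x4.shape)  # skip tensors at 128x128, 64x64 and 32x32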
3.10 Decoder (Upsampling), Generator, and Discriminator
# Decoder: upsampling
class Fcrn_decode(nn.Module):
    def __init__(self, dim=opt.dim):
        super(Fcrn_decode, self).__init__()
        self.dim = dim
        self.conv_2 = nn.Conv2d(in_channels=dim, out_channels=1, kernel_size=3, stride=1, padding=1)
        self.residual_block_6_none_1 = ResidualBlockClass('Detector.Res6', 6 * dim, 6 * dim,
                                                          resample=None, activate='leaky_relu')
        # self.residual_block_7_up_1 = ResidualBlockClass('Detector.Res7', 6*dim, 6*dim, resample='up', activate='leaky_relu')
        self.sa_0 = Self_Attn(6 * dim)
        # 32x32
        self.UpProject_1 = UpProject(6 * dim, 4 * dim)
        self.residual_block_8_up_1 = ResidualBlockClass('Detector.Res8', 6 * dim, 4 * dim,
                                                        resample='up', activate='leaky_relu')
        self.sa_1 = Self_Attn(4 * dim)
        # 64x64
        self.UpProject_2 = UpProject(2 * 4 * dim, 4 * dim)
        self.sa_2 = Self_Attn(4 * dim)
        self.residual_block_9_up_1 = ResidualBlockClass('Detector.Res9', 4 * dim, 4 * dim,
                                                        resample='up', activate='leaky_relu')
        # 128x128
        self.UpProject_3 = UpProject(2 * 4 * dim, 2 * dim)
        self.sa_3 = Self_Attn(2 * dim)
        self.residual_block_10_up_1 = ResidualBlockClass('Detector.Res10', 4 * dim, 2 * dim,
                                                         resample='up', activate='leaky_relu')
        # 256x256
        self.UpProject_4 = UpProject(2 * 2 * dim, 1 * dim)
        self.sa_4 = Self_Attn(1 * dim)
        self.residual_block_11_up_1 = ResidualBlockClass('Detector.Res11', 2 * dim, 1 * dim,
                                                         resample='up', activate='leaky_relu')

    def forward(self, x, x2, x3, x4):
        x = self.residual_block_6_none_1(x)
        x = self.UpProject_1(x)
        x = self.sa_1(x)
        x = self.UpProject_2(torch.cat((x, x4), dim=1))
        x = self.sa_2(x)
        x = self.UpProject_3(torch.cat((x, x3), dim=1))
        # x = self.sa_3(x)
        x = self.UpProject_4(torch.cat((x, x2), dim=1))
        # x = self.sa_4(x)
        x = F.normalize(x, dim=[0, 2, 3])
        x = F.leaky_relu(x)
        x = self.conv_2(x)
        x = torch.sigmoid(x)
        return x


class Generator(nn.Module):
    def __init__(self, dim=opt.dim):
        super(Generator, self).__init__()
        self.dim = dim
        self.conv_1 = nn.Conv2d(in_channels=4, out_channels=1 * dim, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv2d(in_channels=dim, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.batchnormlize = nn.BatchNorm2d(1 * dim)
        self.residual_block_1 = ResidualBlockClass('G.Res1', 1 * dim, 2 * dim, resample='down')  # 128x128
        self.residual_block_2 = ResidualBlockClass('G.Res2', 2 * dim, 4 * dim, resample='down')  # 64x64
        # self.residual_block_2_1 = ResidualBlockClass('G.Res2_1', 4*dim, 4*dim, resample='down')  # 64x64
        # self.residual_block_2_2 = ResidualBlockClass('G.Res2_2', 4*dim, 4*dim, resample=None)    # 64x64
        self.residual_block_3 = ResidualBlockClass('G.Res3', 4 * dim, 4 * dim, resample='down')  # 32x32
        self.residual_block_4 = ResidualBlockClass('G.Res4', 4 * dim, 6 * dim, resample='down')  # 16x16
        self.residual_block_5 = ResidualBlockClass('G.Res5', 6 * dim, 6 * dim, resample=None)    # 16x16
        self.residual_block_6 = ResidualBlockClass('G.Res6', 6 * dim, 6 * dim, resample=None)

    def forward(self, x):
        x = self.conv_1(x)
        x1 = self.residual_block_1(x)   # x1: 2*dim x 128 x 128
        x2 = self.residual_block_2(x1)  # x2: 4*dim x 64 x 64
        # x = self.residual_block_2_1(x)
        # x = self.residual_block_2_2(x)
        x3 = self.residual_block_3(x2)  # x3: 4*dim x 32 x 32
        x = self.residual_block_4(x3)   # x4: 6*dim x 16 x 16
        x = self.residual_block_5(x)
        x = self.residual_block_6(x)
        x = torch.tanh(x)
        return x, x1, x2, x3


class Discriminator(nn.Module):
    def __init__(self, dim=opt.dim):
        super(Discriminator, self).__init__()
        self.dim = dim
        self.conv_1 = nn.Conv2d(in_channels=6 * dim, out_channels=6 * dim, kernel_size=3, stride=1, padding=1)  # 16x16
        self.conv_2 = nn.Conv2d(in_channels=6 * dim, out_channels=6 * dim, kernel_size=3, stride=1, padding=1)
        self.conv_3 = nn.Conv2d(in_channels=6 * dim, out_channels=4 * dim, kernel_size=3, stride=1, padding=1)
        self.bn_1 = nn.BatchNorm2d(6 * dim)
        self.conv_4 = nn.Conv2d(in_channels=4 * dim, out_channels=4 * dim, kernel_size=3, stride=2, padding=1)  # 8x8
        self.conv_5 = nn.Conv2d(in_channels=4 * dim, out_channels=4 * dim, kernel_size=3, stride=1, padding=1)  # 8x8
        self.conv_6 = nn.Conv2d(in_channels=4 * dim, out_channels=2 * dim, kernel_size=3, stride=2, padding=1)  # 4x4
        self.bn_2 = nn.BatchNorm2d(2 * dim)
        self.conv_7 = nn.Conv2d(in_channels=2 * dim, out_channels=2 * dim, kernel_size=3, stride=1, padding=1)  # 4x4
        self.conv_8 = nn.Conv2d(in_channels=2 * dim, out_channels=1 * dim, kernel_size=3, stride=1, padding=1)  # 4x4
        # self.conv_9 = nn.Conv2d(in_channels=1*dim, out_channels=1, kernel_size=4, stride=1, padding=(0, 1), dilation=(1, 3))  # 1x1

    def forward(self, x):
        x = F.leaky_relu(self.conv_1(x), negative_slope=0.02)
        x = F.leaky_relu(self.conv_2(x), negative_slope=0.02)
        x = F.leaky_relu(self.conv_3(x), negative_slope=0.02)
        # x = F.leaky_relu(self.bn_1(x), negative_slope=0.02)
        x = F.leaky_relu(self.conv_4(x), negative_slope=0.02)
        x = F.leaky_relu(self.conv_5(x), negative_slope=0.02)
        x = F.leaky_relu(self.conv_6(x), negative_slope=0.02)
        # x = F.leaky_relu(self.bn_2(x), negative_slope=0.2)
        x = F.leaky_relu(self.conv_7(x), negative_slope=0.02)
        x = F.leaky_relu(self.conv_8(x), negative_slope=0.02)
        # x = self.conv_9(x)
        x = torch.mean(x, dim=[1, 2, 3])
        x = torch.sigmoid(x)
        return x.view(-1, 1).squeeze()


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
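Chaining the encoder and decoder gives the full image-to-mask round trip; a minimal sketch under the same assumptions as the encoder check above:

enc, dec = Fcrn_encode(), Fcrn_decode()
img = torch.randn(2, 3, 256, 256)
feat, x2, x3, x4 = enc(img)
mask = dec(feat, x2, x3, x4)
print(mask.shape)  # torch.Size([2, 1, 256, 256]): per-pixel road probability in (0, 1)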
3.11 Loading the Training Dataset
# Training dataset
class GAN_Dataset(Dataset):
    def __init__(self, transform=None):
        self.transform = transform

    def __len__(self):
        return len(os.listdir(opt.data_path))

    def __getitem__(self, idx):
        img_name = os.listdir(opt.data_path)[idx]
        imgA = cv2.imread(opt.data_path + '/' + img_name)
        imgA = cv2.resize(imgA, (opt.image_scale_w, opt.image_scale_h))
        imgB = cv2.imread(opt.label_path + '/' + img_name[:-4] + '.png', 0)
        imgB = cv2.resize(imgB, (opt.image_scale_w, opt.image_scale_h))
        # imgB[imgB > 30] = 255
        imgB = imgB / 255
        # imgB = imgB.astype('uint8')
        imgB = torch.FloatTensor(imgB)
        imgB = torch.unsqueeze(imgB, 0)
        # print(imgB.shape)
        if self.transform:
            imgA = self.transform(imgA)
        return imgA, imgB


img_road = GAN_Dataset(transform)
train_dataloader = DataLoader(img_road, batch_size=opt.batch, shuffle=True)
print(len(train_dataloader.dataset), train_dataloader.dataset[7][1].shape)
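To verify the loading pipeline, one sample can be un-normalized and displayed (an optional snippet; cv2.imread returns BGR, hence the channel reversal):

imgA, imgB = img_road[0]
plt.subplot(1, 2, 1)
plt.imshow(((imgA.permute(1, 2, 0) + 1) / 2).numpy()[..., ::-1])  # undo Normalize, BGR -> RGB
plt.subplot(1, 2, 2)
plt.imshow(imgB[0].numpy(), cmap='gray')                          # road mask in [0, 1]
plt.show()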
3.12 Test Dataset and Model Setup
# Test dataset; model, optimizer, and scheduler setup
class test_Dataset(Dataset):
    # DATA_PATH = './test/img'
    # LABEL_PATH = './test/lab'
    def __init__(self, transform=None):
        self.transform = transform

    def __len__(self):
        return len(os.listdir('./munich/test/img'))

    def __getitem__(self, idx):
        img_name = os.listdir('./munich/test/img')
        img_name.sort(key=lambda x: int(x[:-4]))
        img_name = img_name[idx]
        imgA = cv2.imread('./munich/test/img' + '/' + img_name)
        imgA = cv2.resize(imgA, (opt.image_scale_w, opt.image_scale_h))
        imgB = cv2.imread('./munich/test/lab' + '/' + img_name[:-4] + '.png', 0)
        imgB = cv2.resize(imgB, (opt.image_scale_w, opt.image_scale_h))
        # imgB = imgB/255
        # imgB[imgB > 30] = 255
        imgB = imgB / 255
        # imgB = imgB.astype('uint8')
        imgB = torch.FloatTensor(imgB)
        imgB = torch.unsqueeze(imgB, 0)
        # print(imgB.shape)
        if self.transform:
            # imgA = imgA/255
            # imgA = np.transpose(imgA, (2, 0, 1))
            # imgA = torch.FloatTensor(imgA)
            imgA = self.transform(imgA)
        return imgA, imgB, img_name[:-4]


img_road_test = test_Dataset(transform)
test_dataloader = DataLoader(img_road_test, batch_size=1, shuffle=False)
print(len(test_dataloader.dataset), test_dataloader.dataset[7][1].shape)

loss = nn.BCELoss()

fcrn_encode = Fcrn_encode()
fcrn_encode = nn.DataParallel(fcrn_encode)
fcrn_encode = fcrn_encode.to(device)
if opt.load_model == 'True':
    fcrn_encode.load_state_dict(torch.load('./model/fcrn_encode_{}_link.pkl'.format(opt.alpha)))

fcrn_decode = Fcrn_decode()
fcrn_decode = nn.DataParallel(fcrn_decode)
fcrn_decode = fcrn_decode.to(device)
if opt.load_model == 'True':
    fcrn_decode.load_state_dict(torch.load('./model/fcrn_decode_{}_link.pkl'.format(opt.alpha)))

Gen = Generator()
Gen = nn.DataParallel(Gen)
Gen = Gen.to(device)
if opt.load_model == 'True':
    Gen.load_state_dict(torch.load('./model/Gen_{}_link.pkl'.format(opt.alpha)))

Dis = Discriminator()
Dis = nn.DataParallel(Dis)
Dis = Dis.to(device)
if opt.load_model == 'True':
    Dis.load_state_dict(torch.load('./model/Dis_{}_link.pkl'.format(opt.alpha)))

Dis_optimizer = optim.Adam(Dis.parameters(), lr=opt.lr_1)
Dis_scheduler = optim.lr_scheduler.StepLR(Dis_optimizer, step_size=800, gamma=0.5)
Fcrn_encode_optimizer = optim.Adam(fcrn_encode.parameters(), lr=opt.lr)
encode_scheduler = optim.lr_scheduler.StepLR(Fcrn_encode_optimizer, step_size=300, gamma=0.5)
Fcrn_decode_optimizer = optim.Adam(fcrn_decode.parameters(), lr=opt.lr)
decode_scheduler = optim.lr_scheduler.StepLR(Fcrn_decode_optimizer, step_size=300, gamma=0.5)
Gen_optimizer = optim.Adam(Gen.parameters(), lr=opt.lr_1)
Gen_scheduler = optim.lr_scheduler.StepLR(Gen_optimizer, step_size=800, gamma=0.5)
3.13 Training Function
# Training function
def train(device, train_dataloader, epoch):
    fcrn_encode.train()
    fcrn_decode.train()
    # Gen.train()
    for batch_idx, (road, road_label) in enumerate(train_dataloader):
        road, road_label = road.to(device), road_label.to(device)
        z = torch.randn(road.shape[0], 1, opt.image_scale_h, opt.image_scale_w, device=device)
        img_noise = torch.cat((road, z), dim=1)
        fake_feature, n1, n2, n3 = Gen(img_noise)
        feature, x2, x3, x4 = fcrn_encode(road, n1, n2, n3)

        # Discriminator step: real features vs. noise-mixed features (with label smoothing)
        Dis_optimizer.zero_grad()
        d_real = Dis(feature.detach())
        d_loss_real = loss(d_real, 0.9 * torch.ones_like(d_real))
        d_fake = Dis((1 - opt.alpha) * feature.detach() + opt.alpha * fake_feature.detach())
        d_loss_fake = loss(d_fake, 0.1 + torch.zeros_like(d_fake))
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        Dis_optimizer.step()

        # Generator step
        Gen_optimizer.zero_grad()
        z = torch.randn(road.shape[0], 1, opt.image_scale_h, opt.image_scale_w, device=device)
        img_noise = torch.cat((road, z), dim=1)
        fake_feature, n1, n2, n3 = Gen(img_noise)
        detect_noise = fcrn_decode((1 - opt.alpha) * feature.detach() + opt.alpha * fake_feature, x2, x3, x4)
        d_fake = Dis((1 - opt.alpha) * feature.detach() + opt.alpha * fake_feature)
        g_loss = loss(d_fake, 0.9 * torch.ones_like(d_fake))
        g_loss -= loss(detect_noise, road_label)
        g_loss.backward()
        Gen_optimizer.step()

        # Encoder/decoder step on clean features
        z = torch.randn(road.shape[0], 1, opt.image_scale_h, opt.image_scale_w, device=device)
        img_noise = torch.cat((road, z), dim=1)
        fake_feature, n1, n2, n3 = Gen(img_noise)
        # feature_img = fake_feature.detach().cpu()
        # feature_img = np.transpose(np.array(utils.make_grid(feature_img, nrow=IMG_CUT)), (1, 2, 0))
        feature, x2, x3, x4 = fcrn_encode(road, n1, n2, n3)
        # detect = fcrn_decode(0.9*feature + 0.1*fake_feature)
        detect = fcrn_decode(feature, x2, x3, x4)
        # detect_img = detect.detach().cpu()
        # detect_img = np.transpose(np.array(utils.make_grid(detect_img, nrow=IMG_CUT)), (1, 2, 0))
        # blur = cv2.GaussianBlur(detect_img*255, (3, 3), 0)
        # _, thresh = cv2.threshold(blur, 120, 255, cv2.THRESH_BINARY)
        fcrn_loss = loss(detect, road_label)
        fcrn_loss += torch.mean(torch.abs(detect - road_label)) / (torch.mean(torch.abs(detect + road_label)) + 0.001)
        Fcrn_encode_optimizer.zero_grad()
        Fcrn_decode_optimizer.zero_grad()
        fcrn_loss.backward()
        Fcrn_encode_optimizer.step()
        Fcrn_decode_optimizer.step()

        # Encoder/decoder step on noise-mixed features
        z = torch.randn(road.shape[0], 1, opt.image_scale_h, opt.image_scale_w, device=device)
        img_noise = torch.cat((road, z), dim=1)
        fake_feature, n1, n2, n3 = Gen(img_noise)
        # ffp, _ = torch.split(fake_feature, [3, 6*opt.dim-3], dim=1)
        # fake_feature_np = ffp.detach().cpu()
        # fake_feature_np = np.transpose(np.array(utils.make_grid(fake_feature_np, nrow=IMG_CUT, padding=0)), (1, 2, 0))
        feature, x2, x3, x4 = fcrn_encode(road, n1, n2, n3)
        # fp, _ = torch.split(feature, [3, 6*opt.dim-3], dim=1)
        # feature_np = fp.detach().cpu()
        # feature_np = np.transpose(np.array(utils.make_grid(feature_np, nrow=IMG_CUT, padding=0)), (1, 2, 0))
        road_np = road.detach().cpu()
        road_np = np.transpose(np.array(utils.make_grid(road_np, nrow=IMG_CUT, padding=0)), (1, 2, 0))
        road_label_np = road_label.detach().cpu()
        road_label_np = np.transpose(np.array(utils.make_grid(road_label_np, nrow=IMG_CUT, padding=0)), (1, 2, 0))
        detect_noise = fcrn_decode((1 - opt.alpha) * feature + opt.alpha * fake_feature.detach(), x2, x3, x4)
        detect_noise_np = detect_noise.detach().cpu()
        detect_noise_np = np.transpose(np.array(utils.make_grid(detect_noise_np, nrow=IMG_CUT, padding=0)), (1, 2, 0))
        blur = cv2.GaussianBlur(detect_noise_np * 255, (3, 3), 0)
        _, thresh = cv2.threshold(blur, 120, 255, cv2.THRESH_BINARY)
        fcrn_loss1 = loss(detect_noise, road_label)
        fcrn_loss1 += torch.mean(torch.abs(detect_noise - road_label)) / (torch.mean(torch.abs(detect_noise + road_label)) + 0.001)
        Fcrn_decode_optimizer.zero_grad()
        Fcrn_encode_optimizer.zero_grad()
        fcrn_loss1.backward()
        Fcrn_decode_optimizer.step()
        Fcrn_encode_optimizer.step()

        # Logging
        writer.add_scalar('g_loss', g_loss.data.item(), global_step=batch_idx)
        writer.add_scalar('d_loss', d_loss.data.item(), global_step=batch_idx)
        writer.add_scalar('Fcrn_loss', fcrn_loss1.data.item(), global_step=batch_idx)
        if batch_idx % 20 == 0:
            # note: the original referenced an undefined num_epochs; opt.num_epochs is used here
            tqdm.write('[{}/{}] [{}/{}] Loss_Dis: {:.6f} Loss_Gen: {:.6f} Loss_Fcrn_encode: {:.6f} Loss_Fcrn_decode: {:.6f}'
                       .format(epoch, opt.num_epochs, batch_idx, len(train_dataloader),
                               d_loss.data.item(), g_loss.data.item(),
                               (fcrn_loss.data.item()) / 2, (fcrn_loss1.data.item()) / 2))
        if batch_idx % 300 == 0:
            mix = np.concatenate(((road_np + 1) * 255 / 2, road_label_np * 255, detect_noise_np * 255), axis=0)
            # feature_np = cv2.resize((feature_np + 1)*255/2, (opt.image_scale_w, opt.image_scale_h))
            # fake_feature_np = cv2.resize((fake_feature_np + 1)*255/2, (opt.image_scale_w, opt.image_scale_h))
            # mix1 = np.concatenate((feature_np, fake_feature_np), axis=0)
            cv2.imwrite("./results/dete{}_{}.png".format(epoch, batch_idx), mix)
            # cv2.imwrite('./results_fcrn_noise/feature{}_{}.png'.format(epoch, batch_idx), mix1)
            # cv2.imwrite("./results/feature{}_{}.png".format(epoch, batch_idx), (feature_img + 1)*255/2)
            # cv2.imwrite("./results9/label{}_{}.png".format(epoch, batch_idx), np.transpose(road_label.cpu().numpy(), (2, 0, 1))*255)
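The listing stops before the epoch driver. A minimal sketch of how train() might be invoked (an assumption; the original script's test branch is not shown here). Note that the StepLR schedulers created above only take effect if .step() is called, which the excerpt never does:

if opt.mode == 'train':
    for epoch in tqdm(range(1, opt.num_epochs + 1)):
        train(device, train_dataloader, epoch)
        # step the learning-rate schedulers once per epoch (assumed placement)
        Dis_scheduler.step()
        Gen_scheduler.step()
        encode_scheduler.step()
        decode_scheduler.step()
        # checkpoint the four sub-networks under the filenames expected by -load_model True
        torch.save(fcrn_encode.state_dict(), './model/fcrn_encode_{}_link.pkl'.format(opt.alpha))
        torch.save(fcrn_decode.state_dict(), './model/fcrn_decode_{}_link.pkl'.format(opt.alpha))
        torch.save(Gen.state_dict(), './model/Gen_{}_link.pkl'.format(opt.alpha))
        torch.save(Dis.state_dict(), './model/Dis_{}_link.pkl'.format(opt.alpha))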