3.14、测试函数
# 测试函数 def test(device, test_dataloader): fcrn_encode.eval() fcrn_decode.eval() # Gen.eval() for batch_idx, (road, road_label, img_name)in enumerate(test_dataloader): road, _ = road.to(device), road_label.to(device) # z = torch.randn(road.shape[0], 1, IMAGE_SCALE, IMAGE_SCALE, device=device) # img_noise = torch.cat((road, z), dim=1) # fake_feature = Gen(img_noise) feature, x2, x3, x4 = fcrn_encode(road) det_road = fcrn_decode(feature, x2, x3, x4) label = det_road.detach().cpu() label = np.transpose(np.array(utils.make_grid(label, padding=0, nrow=1)), (1, 2, 0)) # blur = cv2.GaussianBlur(label*255, (5, 5), 0) _, thresh = cv2.threshold(label*255, 200, 255, cv2.THRESH_BINARY) cv2.imwrite('./test/lab_dete_AVD/{}.png'.format(int(img_name[0])), thresh) print('testing...') print('{}/{}'.format(batch_idx, len(test_dataloader))) print('Done!') # 文件的读取与保存 def iou(path_img, path_lab, epoch): img_name = os.listdir(path_img) img_name.sort(key=lambda x:int(x[:-4])) print(img_name) iou_list = [] for i in range(len(img_name)): det = img_name[i] det = cv2.imread(path_img + '/' + det, 0) lab = img_name[i] lab = cv2.imread(path_lab + '/' + lab[:-4] + '.png', 0) lab = cv2.resize(lab, (opt.image_scale_w, opt.image_scale_h)) count0, count1, a, count2 = 0, 0, 0, 0 for j in range(det.shape[0]): for k in range(det.shape[1]): if det[j][k] != 0 and lab[j][k] != 0: count0 += 1 elif det[j][k] == 0 and lab[j][k] != 0: count1 += 1 elif det[j][k] != 0 and lab[j][k] == 0: count2 += 1 #iou = (count1 + count2)/(det.shape[0] * det.shape[1]) iou = count0/(count1 + count0 + count2 + 0.0001) iou_list.append(iou) print(img_name[i], ':', iou) print('mean_iou:', sum(iou_list)/len(iou_list)) with open('./munich_iou.txt',"a") as f: f.write("model_num" + " " + str(epoch) + " " + 'mean_iou:' + str(sum(iou_list)/len(iou_list)) + '\n')
3.15、主函数
# Main entry point
def _checkpoint_path(name, alpha, stamp):
    """Return the checkpoint path ``./model/<name>_<alpha><stamp>munich.pkl``."""
    # A single format call fills every field. The earlier form
    # ``'..._{}' + now + 'munich.pkl'.format(alpha)`` applied .format only to
    # the last literal, leaving a literal "{}" in the file name.
    return './model/{}_{}{}munich.pkl'.format(name, alpha, stamp)


if __name__ == '__main__':
    if opt.mode == 'train':
        num_epochs = opt.num_epochs
        for epoch in tqdm(range(num_epochs)):
            train(device, train_dataloader, epoch)
            # Advance every learning-rate scheduler once per epoch.
            Dis_scheduler.step()
            Gen_scheduler.step()
            encode_scheduler.step()
            decode_scheduler.step()
            if epoch % 50 == 0:
                now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
                os.makedirs('./model', exist_ok=True)
                checkpoints = (
                    ('Dis', Dis),
                    ('Gen', Gen),
                    ('fcrn_decode', fcrn_decode),
                    ('fcrn_encode', fcrn_encode),
                )
                for name, net in checkpoints:
                    torch.save(net.state_dict(),
                               _checkpoint_path(name, opt.alpha, now))
                # NOTE(review): evaluation is assumed to belong to the
                # checkpoint branch; confirm against the original layout.
                print('testing...')
                test(device, test_dataloader)
                iou('./test/lab_dete_AVD', './munich/test/lab', epoch)
    elif opt.mode == 'test':
        test(device, test_dataloader)
        iou('./test/lab_dete_AVD', './munich/test/lab', 'test')
四、 实验运行步骤与运行结果
4.1、 运行步骤
- 1
- 2
3
4
5
4.2、 运行的结果
1
2
3
4
5
- 6
- 7
8
五、 实验总结
- 从运行结果可以看出，使用 U-Net 网络在目标数据集上进行训练后，可以对数据集中的道路目标实现准确的检测。
- 在大规模数据集上进行实验时，由于程序在 CPU 上运行，U-Net 网络的训练耗时将近 10 小时；不过，最终得到的目标检测结果非常准确。