模式识别与图像处理课程实验二:基于UNet的目标检测网络(下)

简介: 模式识别与图像处理课程实验二:基于UNet的目标检测网络(下)

3.14、测试函数


# 测试函数
def test(device, test_dataloader):
    fcrn_encode.eval()
    fcrn_decode.eval()
#     Gen.eval()
    for batch_idx, (road, road_label, img_name)in enumerate(test_dataloader):
        road, _ = road.to(device), road_label.to(device)
        # z = torch.randn(road.shape[0], 1, IMAGE_SCALE, IMAGE_SCALE, device=device)
        # img_noise = torch.cat((road, z), dim=1)
        # fake_feature = Gen(img_noise)
        feature, x2, x3, x4  = fcrn_encode(road)
        det_road = fcrn_decode(feature, x2, x3, x4)
        label = det_road.detach().cpu()
        label = np.transpose(np.array(utils.make_grid(label, padding=0, nrow=1)), (1, 2, 0))
        # blur = cv2.GaussianBlur(label*255, (5, 5), 0)
        _, thresh = cv2.threshold(label*255, 200, 255, cv2.THRESH_BINARY)
        cv2.imwrite('./test/lab_dete_AVD/{}.png'.format(int(img_name[0])), thresh)
        print('testing...')
        print('{}/{}'.format(batch_idx, len(test_dataloader)))
    print('Done!')
# 文件的读取与保存
def iou(path_img, path_lab, epoch):
    img_name = os.listdir(path_img)
    img_name.sort(key=lambda x:int(x[:-4]))
    print(img_name)
    iou_list = []
    for i in range(len(img_name)):
        det = img_name[i]
        det = cv2.imread(path_img + '/' + det, 0)
        lab = img_name[i]
        lab = cv2.imread(path_lab + '/' + lab[:-4] + '.png', 0)
        lab = cv2.resize(lab, (opt.image_scale_w, opt.image_scale_h))
        count0, count1, a, count2 = 0, 0, 0, 0
        for j in range(det.shape[0]):
            for k in range(det.shape[1]):
                if det[j][k] != 0 and lab[j][k] != 0:
                    count0 += 1
                elif det[j][k] == 0 and lab[j][k] != 0:
                    count1 += 1
                elif det[j][k] != 0 and lab[j][k] == 0:
                    count2 += 1
                #iou = (count1 + count2)/(det.shape[0] * det.shape[1])
                iou = count0/(count1 + count0 + count2 + 0.0001)
        iou_list.append(iou)
        print(img_name[i], ':', iou)
    print('mean_iou:', sum(iou_list)/len(iou_list))
    with open('./munich_iou.txt',"a") as f:
        f.write("model_num" + " " + str(epoch) + " " + 'mean_iou:' + str(sum(iou_list)/len(iou_list)) + '\n')

3.15、主函数


# 主函数
if __name__ == '__main__':
    if opt.mode == 'train':
        num_epochs = opt.num_epochs
        for epoch in tqdm(range(num_epochs)):
            train(device, train_dataloader, epoch)
            Dis_scheduler.step()
            Gen_scheduler.step()
            encode_scheduler.step()
            decode_scheduler.step()
            if epoch % 50 == 0:
                now = time.strftime("%Y-%m-%d-%H_%M_%S",time.localtime(time.time()))
                torch.save(Dis.state_dict(), './model/Dis_{}'+ now +'munich.pkl'.format(opt.alpha))
                torch.save(Gen.state_dict(), './model/Gen_{}'+ now +'munich.pkl'.format(opt.alpha))
                torch.save(fcrn_decode.state_dict(), './model/fcrn_decode_{}'+ now +'munich.pkl'.format(opt.alpha))
                torch.save(fcrn_encode.state_dict(), './model/fcrn_encode_{}'+ now +'munich.pkl'.format(opt.alpha))
                print('testing...')
                test(device, test_dataloader)
                iou('./test/lab_dete_AVD', './munich/test/lab', epoch)
    if opt.mode == 'test':
        test(device, test_dataloader)
        iou('./test/lab_dete_AVD', './munich/test/lab', 'test')


四、 实验运行步骤与运行结果

4.1、 运行步骤

  • 1


364bfc89eca14062a5c0f38e8b67ed7e.png


  • 2


06b024c383dc4bd0baa4cb1d639b63cf.png


3


12038bf571874e5a9c4d9d8423815a3c.png


4


c124ba51b78645c88b8a0348e5377115.png


5


be9df6a2cf7c4802b8b1b9e094c6bac2.png


4.2、 运行的结果

1

fa5946bd93c647938b5199f6b7261067.png


2


ae5f895078264912849e9c980cff07b0.png


3


c36890a9a2364595a4f8d8f15f8d386f.png

4


d170b6cdb0c548a3b6edf485e1534490.png


5


af4dcee3eafd4981939a52d96a34aae4.png

-6


2ff0a988fe114567b8b0cc5b9f803c41.png

  • 7


eb9086280a2b47d9bf1ce6442761c0a9.png


8

61d96a40f7744650a6a5100b7fc0659f.png



五、 实验总结


  • 从运行结果可以看出,用Unet网络训练目标数据集,可以对数据集的道路目标实现准确的检测。
  • 从大量的数据集中进行测试,在CPU上运行,Unet网络测试数据用了将近10小时的训练时间。但是,得到的目标检测的结果是非常准确的。



31a3b3bafdd34abd8155ab4a5cb14a7f.jpg

相关文章
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
106 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
1月前
|
机器学习/深度学习 数据可视化 计算机视觉
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
这篇文章详细介绍了如何通过可视化深度学习中每层特征层来理解网络的内部运作,并使用ResNet系列网络作为例子,展示了如何在训练过程中加入代码来绘制和保存特征图。
63 1
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
|
24天前
|
机器学习/深度学习 计算机视觉 网络架构
【YOLO11改进 - C3k2融合】C3k2DWRSeg二次创新C3k2_DWR:扩张式残差分割网络,提高特征提取效率和多尺度信息获取能力,助力小目标检测
【YOLO11改进 - C3k2融合】C3k2DWRSeg二次创新C3k2_DWR:扩张式残差分割网络,提高特征提取效率和多尺度信息获取能力,助力小目DWRSeg是一种高效的实时语义分割网络,通过将多尺度特征提取分为区域残差化和语义残差化两步,提高了特征提取效率。它引入了Dilation-wise Residual (DWR) 和 Simple Inverted Residual (SIR) 模块,优化了不同网络阶段的感受野。在Cityscapes和CamVid数据集上的实验表明,DWRSeg在准确性和推理速度之间取得了最佳平衡,达到了72.7%的mIoU,每秒319.5帧。代码和模型已公开。
【YOLO11改进 - C3k2融合】C3k2DWRSeg二次创新C3k2_DWR:扩张式残差分割网络,提高特征提取效率和多尺度信息获取能力,助力小目标检测
|
1月前
|
机器学习/深度学习 网络架构 计算机视觉
目标检测笔记(一):不同模型的网络架构介绍和代码
这篇文章介绍了ShuffleNetV2网络架构及其代码实现,包括模型结构、代码细节和不同版本的模型。ShuffleNetV2是一个高效的卷积神经网络,适用于深度学习中的目标检测任务。
74 1
目标检测笔记(一):不同模型的网络架构介绍和代码
|
1月前
|
网络协议 网络虚拟化 网络架构
【网络实验】/主机/路由器/交换机/网关/路由协议/RIP+OSPF/DHCP(上)
【网络实验】/主机/路由器/交换机/网关/路由协议/RIP+OSPF/DHCP(上)
67 1
|
24天前
|
机器学习/深度学习 计算机视觉 网络架构
【YOLO11改进 - C3k2融合】C3k2融合DWRSeg二次创新C3k2_DWRSeg:扩张式残差分割网络,提高特征提取效率和多尺度信息获取能力,助力小目标检测
【YOLO11改进 - C3k2融合】C3k2融合DWRSDWRSeg是一种高效的实时语义分割网络,通过将多尺度特征提取方法分解为区域残差化和语义残差化两步,提高了多尺度信息获取的效率。网络设计了Dilation-wise Residual (DWR) 和 Simple Inverted Residual (SIR) 模块,分别用于高阶段和低阶段,以充分利用不同感受野的特征图。实验结果表明,DWRSeg在Cityscapes和CamVid数据集上表现出色,以每秒319.5帧的速度在NVIDIA GeForce GTX 1080 Ti上达到72.7%的mIoU,超越了现有方法。代码和模型已公开。
|
2月前
|
网络架构
静态路由 网络实验
本文介绍了如何通过配置静态路由实现不同网络设备间的通信,包括网络拓扑图、设备IP配置、查看路由表信息、配置静态路由和测试步骤。通过在路由器上设置静态路由,使得不同子网内的设备能够互相通信。
静态路由 网络实验
|
1月前
|
网络协议 数据安全/隐私保护 网络虚拟化
【网络实验】/主机/路由器/交换机/网关/路由协议/RIP+OSPF/DHCP(下)
【网络实验】/主机/路由器/交换机/网关/路由协议/RIP+OSPF/DHCP(下)
60 0
|
1月前
|
移动开发 网络协议 测试技术
Mininet多数据中心网络拓扑流量带宽实验
Mininet多数据中心网络拓扑流量带宽实验
59 0
|
1月前
|
Kubernetes 容器
基于Ubuntu-22.04安装K8s-v1.28.2实验(三)数据卷挂载NFS(网络文件系统)
基于Ubuntu-22.04安装K8s-v1.28.2实验(三)数据卷挂载NFS(网络文件系统)
138 0
下一篇
无影云桌面