模式识别与图像处理课程实验二:基于UNet的目标检测网络(下)

简介: 模式识别与图像处理课程实验二:基于UNet的目标检测网络(下)

3.14、测试函数


# 测试函数
def test(device, test_dataloader):
    fcrn_encode.eval()
    fcrn_decode.eval()
#     Gen.eval()
    for batch_idx, (road, road_label, img_name)in enumerate(test_dataloader):
        road, _ = road.to(device), road_label.to(device)
        # z = torch.randn(road.shape[0], 1, IMAGE_SCALE, IMAGE_SCALE, device=device)
        # img_noise = torch.cat((road, z), dim=1)
        # fake_feature = Gen(img_noise)
        feature, x2, x3, x4  = fcrn_encode(road)
        det_road = fcrn_decode(feature, x2, x3, x4)
        label = det_road.detach().cpu()
        label = np.transpose(np.array(utils.make_grid(label, padding=0, nrow=1)), (1, 2, 0))
        # blur = cv2.GaussianBlur(label*255, (5, 5), 0)
        _, thresh = cv2.threshold(label*255, 200, 255, cv2.THRESH_BINARY)
        cv2.imwrite('./test/lab_dete_AVD/{}.png'.format(int(img_name[0])), thresh)
        print('testing...')
        print('{}/{}'.format(batch_idx, len(test_dataloader)))
    print('Done!')
# 文件的读取与保存
def iou(path_img, path_lab, epoch):
    img_name = os.listdir(path_img)
    img_name.sort(key=lambda x:int(x[:-4]))
    print(img_name)
    iou_list = []
    for i in range(len(img_name)):
        det = img_name[i]
        det = cv2.imread(path_img + '/' + det, 0)
        lab = img_name[i]
        lab = cv2.imread(path_lab + '/' + lab[:-4] + '.png', 0)
        lab = cv2.resize(lab, (opt.image_scale_w, opt.image_scale_h))
        count0, count1, a, count2 = 0, 0, 0, 0
        for j in range(det.shape[0]):
            for k in range(det.shape[1]):
                if det[j][k] != 0 and lab[j][k] != 0:
                    count0 += 1
                elif det[j][k] == 0 and lab[j][k] != 0:
                    count1 += 1
                elif det[j][k] != 0 and lab[j][k] == 0:
                    count2 += 1
                #iou = (count1 + count2)/(det.shape[0] * det.shape[1])
                iou = count0/(count1 + count0 + count2 + 0.0001)
        iou_list.append(iou)
        print(img_name[i], ':', iou)
    print('mean_iou:', sum(iou_list)/len(iou_list))
    with open('./munich_iou.txt',"a") as f:
        f.write("model_num" + " " + str(epoch) + " " + 'mean_iou:' + str(sum(iou_list)/len(iou_list)) + '\n')

3.15、主函数


# 主函数
if __name__ == '__main__':
    if opt.mode == 'train':
        num_epochs = opt.num_epochs
        for epoch in tqdm(range(num_epochs)):
            train(device, train_dataloader, epoch)
            Dis_scheduler.step()
            Gen_scheduler.step()
            encode_scheduler.step()
            decode_scheduler.step()
            if epoch % 50 == 0:
                now = time.strftime("%Y-%m-%d-%H_%M_%S",time.localtime(time.time()))
                torch.save(Dis.state_dict(), './model/Dis_{}'+ now +'munich.pkl'.format(opt.alpha))
                torch.save(Gen.state_dict(), './model/Gen_{}'+ now +'munich.pkl'.format(opt.alpha))
                torch.save(fcrn_decode.state_dict(), './model/fcrn_decode_{}'+ now +'munich.pkl'.format(opt.alpha))
                torch.save(fcrn_encode.state_dict(), './model/fcrn_encode_{}'+ now +'munich.pkl'.format(opt.alpha))
                print('testing...')
                test(device, test_dataloader)
                iou('./test/lab_dete_AVD', './munich/test/lab', epoch)
    if opt.mode == 'test':
        test(device, test_dataloader)
        iou('./test/lab_dete_AVD', './munich/test/lab', 'test')


四、 实验运行步骤与运行结果

4.1、 运行步骤

  • 1


364bfc89eca14062a5c0f38e8b67ed7e.png


  • 2


06b024c383dc4bd0baa4cb1d639b63cf.png


3


12038bf571874e5a9c4d9d8423815a3c.png


4


c124ba51b78645c88b8a0348e5377115.png


5


be9df6a2cf7c4802b8b1b9e094c6bac2.png


4.2、 运行的结果

1

fa5946bd93c647938b5199f6b7261067.png


2


ae5f895078264912849e9c980cff07b0.png


3


c36890a9a2364595a4f8d8f15f8d386f.png

4


d170b6cdb0c548a3b6edf485e1534490.png


5


af4dcee3eafd4981939a52d96a34aae4.png

-6


2ff0a988fe114567b8b0cc5b9f803c41.png

  • 7


eb9086280a2b47d9bf1ce6442761c0a9.png


8

61d96a40f7744650a6a5100b7fc0659f.png



五、 实验总结


  • 从运行结果可以看出,用Unet网络训练目标数据集,可以对数据集的道路目标实现准确的检测。
  • 从大量的数据集中进行测试,在CPU上运行,Unet网络测试数据用了将近10小时的训练时间。但是,得到的目标检测的结果是非常准确的。



31a3b3bafdd34abd8155ab4a5cb14a7f.jpg

相关文章
|
2天前
|
机器学习/深度学习 存储 监控
数据分享|Python卷积神经网络CNN身份识别图像处理在疫情防控下口罩识别、人脸识别
数据分享|Python卷积神经网络CNN身份识别图像处理在疫情防控下口罩识别、人脸识别
|
2天前
|
存储 算法 Windows
课程视频|R语言bnlearn包:贝叶斯网络的构造及参数学习的原理和实例(下)
课程视频|R语言bnlearn包:贝叶斯网络的构造及参数学习的原理和实例
|
2天前
|
算法 数据可视化 数据挖掘
课程视频|R语言bnlearn包:贝叶斯网络的构造及参数学习的原理和实例(上)
课程视频|R语言bnlearn包:贝叶斯网络的构造及参数学习的原理和实例
|
2天前
|
机器学习/深度学习 算法 计算机视觉
[YOLOv8/YOLOv7/YOLOv5系列算法改进NO.5]改进特征融合网络PANET为BIFPN(更新添加小目标检测层yaml)
本文介绍了改进YOLOv5以解决处理复杂背景时可能出现的错漏检问题。
111 5
|
2天前
|
存储 缓存 网络协议
【计网·湖科大·思科】实验二 计算机网络的寻址问题
【计网·湖科大·思科】实验二 计算机网络的寻址问题
4 0
|
2天前
|
运维 监控 安全
网络安全预习课程笔记(四到八节)
网络安全领域的岗位多样化,包括应急响应、代码审计、安全研究、工具编写、报告撰写、渗透测试和驻场服务等。其中,应急响应处理系统故障和安全事件,代码审计涉及源码漏洞查找,安全研究侧重漏洞挖掘,工具编写则要开发自动化工具,报告撰写需要良好的写作能力。渗透测试涵盖Web漏洞和内网渗透。岗位选择受公司、部门和领导的影响。此外,还可以参与CTF比赛或兼职安全事件挖掘。了解不同岗位职责和技能需求,如安全运维工程师需要熟悉Web安全技术、系统加固、安全产品和日志分析等。同时,渗透测试包括信息收集、威胁建模、漏洞分析、攻击实施和报告撰写等步骤。学习网络安全相关术语,如漏洞、木马、后门等,有助于深入理解和学习。
|
2天前
|
前端开发 数据挖掘 数据建模
课程视频|R语言bnlearn包:贝叶斯网络的构造及参数学习的原理和实例(中)
课程视频|R语言bnlearn包:贝叶斯网络的构造及参数学习的原理和实例
|
2天前
|
JavaScript Java 测试技术
基于Java的网络类课程思政学习系统的设计与实现(源码+lw+部署文档+讲解等)
基于Java的网络类课程思政学习系统的设计与实现(源码+lw+部署文档+讲解等)
32 0
基于Java的网络类课程思政学习系统的设计与实现(源码+lw+部署文档+讲解等)
|
2天前
|
弹性计算 网络协议 关系型数据库
网络技术基础阿里云实验——企业级云上网络构建实践
实验地址:<https://developer.aliyun.com/adc/scenario/65e54c7876324bbe9e1fb18665719179> 本文档指导在阿里云上构建跨地域的网络环境,涉及杭州和北京两个地域。任务包括创建VPC、交换机、ECS实例,配置VPC对等连接,以及设置安全组和网络ACL规则以实现特定服务间的互访。例如,允许北京的研发服务器ECS-DEV访问杭州的文件服务器ECS-FS的SSH服务,ECS-FS访问ECS-WEB01的SSH服务,ECS-WEB01访问ECS-DB01的MySQL服务,并确保ECS-WEB03对外提供HTTP服务。
网络技术基础(19)——PPPoE实验
【3月更文挑战第5天】该文介绍了PPPoE拨号上网的模拟实验。通过运营商提供的PPPoverEthernet服务,设备可以动态获取内网地址并连接到Internet。实验包括服务器和客户端的配置:服务器设置地址池、认证账号和虚拟模板,并绑定到物理接口;客户端配置拨号规则、虚拟拨号口及内网网关,以实现永久在线连接。实验结果显示,客户端成功通过Dialer接口连接到服务器,实现了上网功能。

热门文章

最新文章