目标检测实战(三):YOLO-Nano训练、测试、验证详细步骤

简介: 本文介绍了YOLO-Nano在目标检测中的训练、测试及验证步骤。YOLO-Nano是一个轻量级目标检测模型,使用ShuffleNet-v2作为主干网络,结合FPN+PAN特征金字塔和NanoDet的检测头。文章详细说明了训练前的准备、源代码下载、数据集准备、参数调整、模型测试、FPS测试、VOC-map测试、模型训练、模型测试和验证等步骤,旨在帮助开发者高效实现目标检测任务。

训练前准备

包括代码、数据集(VOC或者COCO)、调参等等…

下载源代码

受NanoDet启发的新版YOLO-Nano
网络架构分析:主干网:shufflenetv2,特征金字塔采用FPN+PAN,head用的是NanoDet的head

优化模型方式—多尺度学习、余弦退火、warmup、高分辨率、mosaic、KM聚类
损失函数:ciou_loss
预测框筛选:DIoU_nms

下载VOC和COCO数据集

这里给出一个公共数据集下载的网址:点击
修改数据集路径(voc0712.py的26行)

调参

这里主要是调整训练和epoch,no_warm_up选择False代表要采用预热模型的方式。
修改:
- epoch:config.py 里 5-8行
训练时设置use_cuda为True
训练时设置主干网的模型(yolo_nano_1.0x/yolo_nano_0.5x)
训练时设置dataset的类型VOC/COCO
如果要使用tensorboard则修改一下这里在这里插入图片描述
遇到的错误修改:
在这里插入图片描述

测试现有模式

测试图片检测和FPS

  • trained_model—修改为文件夹下存在的模型

自己根据评估那个写了个可以选择测试多张图片检测情况和FPS的代码,修改mode方式就可以。

import time
from PIL import Image
import cv2
import numpy as np
import os
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from data import *
import numpy as np
import cv2
import tools
import time
from data.voc0712 import VOCAnnotationTransform

parser = argparse.ArgumentParser(description='YOLO-Nano Detection')
parser.add_argument('-v', '--version', default='yolo_nano_1.0x',
                    help='yolo_nano_0.5x, yolo_nano_1.0x.')
parser.add_argument('-d', '--dataset', default='voc',
                    help='voc, coco-val.')
parser.add_argument('-size', '--input_size', default=416, type=int,
                    help='input_size')
parser.add_argument('--trained_model',
                    default=r'weights/voc/yolo_nano_1.0x/yolo_nano_1.0x_67.23.pth',
                    type=str, help='Trained state_dict file path to open')
parser.add_argument('--conf_thresh', default=0.1, type=float,
                    help='Confidence threshold')
parser.add_argument('--nms_thresh', default=0.50, type=float,
                    help='NMS threshold')
parser.add_argument('--visual_threshold', default=0.3, type=float,
                    help='Final confidence threshold')
parser.add_argument('--cuda', action='store_true', default=True,
                    help='use cuda.')
parser.add_argument('--diou_nms', action='store_true', default=False,
                    help='use diou nms.')

args = parser.parse_args()

def vis(img, bboxes, scores, cls_inds, thresh, class_colors, class_names, class_indexs=None, dataset='voc'):
    if dataset == 'voc':
        for i, box in enumerate(bboxes):
            cls_indx = cls_inds[i]
            xmin, ymin, xmax, ymax = box
            if scores[i] > thresh:
                cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), class_colors[int(cls_indx)], 1)
                cv2.rectangle(img, (int(xmin), int(abs(ymin) - 20)), (int(xmax), int(ymin)),
                              class_colors[int(cls_indx)], -1)
                mess = '%s' % (class_names[int(cls_indx)])
                cv2.putText(img, mess, (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    elif dataset == 'coco-val' and class_indexs is not None:
        for i, box in enumerate(bboxes):
            cls_indx = cls_inds[i]
            xmin, ymin, xmax, ymax = box
            if scores[i] > thresh:
                cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), class_colors[int(cls_indx)], 1)
                cv2.rectangle(img, (int(xmin), int(abs(ymin) - 20)), (int(xmax), int(ymin)),
                              class_colors[int(cls_indx)], -1)
                cls_id = class_indexs[int(cls_indx)]
                cls_name = class_names[cls_id]
                # mess = '%s: %.3f' % (cls_name, scores[i])
                mess = '%s' % (cls_name)
                cv2.putText(img, mess, (int(xmin), int(ymin - 5)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    return img

def test(net, device, testset, transform, thresh, class_colors=None, class_names=None, class_indexs=None,
         dataset='voc',save=False,test_num=100,mode=''):
    num_images = len(testset)
    test_time,idx=[],1
    for index in range(num_images):
        print('Testing image {:d}/{:d}....'.format(index + 1, num_images))
        img, _ = testset.pull_image(index)
        img_tensor, _, h, w, offset, scale = testset.pull_item(index)

        # to tensor
        x = img_tensor.unsqueeze(0).to(device)

        t0 = time.time()
        # forward
        bboxes, scores, cls_inds = net(x)
        print("detection time used ", time.time() - t0, "s")
        if idx!=1:
            test_time.append(float(time.time() - t0))
        # scale each detection back up to the image
        max_line = max(h, w)
        # map the boxes to input image with zero padding
        bboxes *= max_line
        # map to the image without zero padding
        bboxes -= (offset * max_line)

        img_processed = vis(img, bboxes, scores, cls_inds, thresh, class_colors, class_names, class_indexs, dataset)
        if mode=='fps':
            if idx == test_num:
                break
            idx += 1
        else:
            cv2.imshow('detection', img_processed)
            cv2.waitKey(0)
            if save:
                print('Saving the' + str(index) + '-th image ...')
                save_path=r'D:\pycharm_Z\YOLO-Nano\img_files\save_detection_pic/'
                os.makedirs(os.path.dirname(save_path),exist_ok=True)
                cv2.imwrite( save_path+ str(index).zfill(6) +'.jpg', img_processed)
    return test_time

if __name__ == '__main__':
    # get device
    if args.cuda:
        print('use cuda')
        cudnn.benchmark = True
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    input_size = [args.input_size, args.input_size]

    # dataset
    if args.dataset == 'voc':
        print('test on voc ...')
        class_names = VOC_CLASSES
        class_indexs = None
        num_classes = 20
        anchor_size = MULTI_ANCHOR_SIZE
        dataset = VOCDetection(root=VOC_ROOT,
                               img_size=None,
                               image_sets=[('2007', 'test')],
                               transform=BaseTransform(input_size))

    elif args.dataset == 'coco-val':
        print('test on coco-val ...')
        class_names = coco_class_labels
        class_indexs = coco_class_index
        num_classes = 80
        anchor_size = MULTI_ANCHOR_SIZE_COCO
        dataset = COCODataset(
            data_dir=coco_root,
            json_file='instances_val2017.json',
            name='val2017',
            transform=BaseTransform(input_size),
            img_size=input_size[0])

    class_colors = [(np.random.randint(255), np.random.randint(255), np.random.randint(255)) for _ in
                    range(num_classes)]

    # build model
    if args.version == 'yolo_nano_0.5x':
        from models.yolo_nano import YOLONano

        backbone = '0.5x'
        net = YOLONano(device, input_size=input_size, num_classes=num_classes, anchor_size=anchor_size,
                       backbone=backbone)
        print('Let us train yolo_nano_0.5x ......')

    if args.version == 'yolo_nano_1.0x':
        from models.yolo_nano import YOLONano

        backbone = '1.0x'
        net = YOLONano(device, input_size=input_size, num_classes=num_classes, anchor_size=anchor_size,
                       backbone=backbone)
        print('Let us train yolo_nano_1.0x ......')

    else:
        print('Unknown version !!!')
        exit()

    net.load_state_dict(torch.load(args.trained_model, map_location=device))
    net.to(device).eval()

    print('Finished loading model!')
    #-------------------------------------------------------------------------#
    #   mode用于指定测试的模式:
    #   'predict'表示 多张图片预测和保存
    #   'fps'表示测试fps
    #-------------------------------------------------------------------------#
    mode = "predict"

    if mode == "predict":
        # evaluation
        test(net=net,
             device=device,
             testset=dataset,
             transform=BaseTransform(input_size),
             thresh=args.visual_threshold,
             class_colors=class_colors,
             class_names=class_names,
             class_indexs=class_indexs,
             dataset=args.dataset,
             save=True,
             mode="predict"
             )

    elif mode == "fps":
        # evaluation
        test_num=10
        time_all=test(net=net,
             device=device,
             testset=dataset,
             transform=BaseTransform(input_size),
             thresh=args.visual_threshold,
             class_colors=class_colors,
             class_names=class_names,
             class_indexs=class_indexs,
             dataset=args.dataset,
             test_num=test_num,
             mode="fps"
             )
        time_avg=sum(time_all)/len(time_all)
        print('the whole time:{}'.format(time_avg))
        print('fps:{}'.format(1/time_avg))

测试VOC-map

找到eval.py并修改测试模型路径和指定数据集路径
在这里插入图片描述
在这里插入图片描述
然后就可以运行eval.py文件。

!!!!如果遇到错误R = [obj for obj in recs[imagename] if obj[‘name’] == classname] KeyError: ‘007765’
解决办法------训练前需要将cache中的pki文件(找到voc_eval/test)以及VOCdevkit2007中annotations_cache的缓存删掉(在你的数据集里面会新建这个文件)

开始训练模型

训练文件脚本:train.py
修改完上述的地方应该就可以直接运行了(环境没问题的情况下)
在这里插入图片描述
每10个epoch会测试一下map
如果使用了tensorboard,可以在终端输入
tensorboard --logdir=D:\pycharm_Z\YOLO-Nano\log\voc\yolo_nano_1.0x\2021-09-13-13-31-02
出现一个网址进去就是
在这里插入图片描述
一个为分类loss,一个为回归loss,一个为多目标loss

测试模型

测试模型就可以使用我上面给出的代码,检测一下fps和图片检测情况
在这里插入图片描述
在这里插入图片描述

验证模型

如果是VOC数据集
验证模型的指标为Map,以及各类AP情况,通过eval.py,将模型改为你训练后的模型即可
如果是COCO数据集同理

目录
相关文章
|
2月前
|
数据采集 JSON JavaScript
Cypress 插件实战:让测试更稳定,不再“偶尔掉链子”
本文分享如何通过自定义Cypress插件解决测试不稳定的痛点。插件可实现智能等待、数据预处理等能力,替代传统硬性等待,有效减少偶发性失败,提升测试效率和可维护性。文内包含具体实现方法与最佳实践。
|
3月前
|
存储 关系型数据库 测试技术
玩转n8n测试自动化:核心节点详解与测试实战指南
n8n中节点是自动化测试的核心,涵盖触发器、数据操作、逻辑控制和工具节点。通过组合节点,测试工程师可构建高效、智能的测试流程,提升测试自动化能力。
|
2月前
|
人工智能 自然语言处理 JavaScript
Playwright MCP在UI回归测试中的实战:构建AI自主测试智能体
Playwright MCP结合AI智能体,革新UI回归测试:通过自然语言驱动浏览器操作,降低脚本编写门槛,提升测试效率与覆盖范围。借助快照解析、智能定位与Jira等工具集成,实现从需求描述到自动化执行的闭环,推动测试迈向智能化、民主化新阶段。
|
3月前
|
人工智能 数据可视化 测试技术
AI 时代 API 自动化测试实战:Postman 断言的核心技巧与实战应用
AI 时代 API 自动化测试实战:Postman 断言的核心技巧与实战应用
486 11
|
3月前
|
测试技术 UED 开发者
性能测试报告-用于项目的性能验证、性能调优、发现性能缺陷等应用场景
性能测试报告用于评估系统性能、稳定性和安全性,涵盖测试环境、方法、指标分析及缺陷优化建议,是保障软件质量与用户体验的关键文档。
|
4月前
|
算法 测试技术 API
从自学到实战:一位测试工程师的成长之路
在技术快速发展的今天,自动化测试已成为提升职场竞争力的关键技能。本文讲述了一位测试工程师从自学到实战的成长之路,分享他在学习UI、APP和API自动化过程中遇到的挑战,以及如何通过实际项目磨炼技术、突破瓶颈。他从最初自学的迷茫,到实战中发现问题、解决问题,再到得到导师指导,逐步掌握测试开发的核心思维,并向测试平台建设方向迈进。文章总结了他从理论到实践、从执行到思考的转变经验,强调了实战、导师指导和技术服务于业务的重要性。最后,邀请读者分享自己的技术突破故事,共同交流成长。
|
11月前
|
数据可视化 前端开发 测试技术
接口测试新选择:Postman替代方案全解析
在软件开发中,接口测试工具至关重要。Postman长期占据主导地位,但随着国产工具的崛起,越来越多开发者转向更适合中国市场的替代方案——Apifox。它不仅支持中英文切换、完全免费不限人数,还具备强大的可视化操作、自动生成文档和API调试功能,极大简化了开发流程。
|
6月前
|
Java 测试技术 容器
Jmeter工具使用:HTTP接口性能测试实战
希望这篇文章能够帮助你初步理解如何使用JMeter进行HTTP接口性能测试,有兴趣的话,你可以研究更多关于JMeter的内容。记住,只有理解并掌握了这些工具,你才能充分利用它们发挥其应有的价值。+
1000 23
|
8月前
|
SQL 安全 测试技术
2025接口测试全攻略:高并发、安全防护与六大工具实战指南
本文探讨高并发稳定性验证、安全防护实战及六大工具(Postman、RunnerGo、Apipost、JMeter、SoapUI、Fiddler)选型指南,助力构建未来接口测试体系。接口测试旨在验证数据传输、参数合法性、错误处理能力及性能安全性,其重要性体现在早期发现问题、保障系统稳定和支撑持续集成。常用方法包括功能、性能、安全性及兼容性测试,典型场景涵盖前后端分离开发、第三方服务集成与数据一致性检查。选择合适的工具需综合考虑需求与团队协作等因素。
1179 24
|
8月前
|
SQL 测试技术
除了postman还有什么接口测试工具
最好还是使用国内的接口测试软件,其实国内替换postman的软件有很多,这里我推荐使用yunedit-post这款接口测试工具来代替postman,因为它除了接口测试功能外,在动态参数的支持、后置处理执行sql语句等支持方面做得比较好。而且还有接口分享功能,可以生成接口文档给团队在线浏览。
342 2