yolov7开源代码讲解--训练代码

简介: 笔记

以前看CNN训练代码的时候,往往代码比较易懂,基本很快就能知道各个模块功能,但到了后面很多出来的网络中,由于加入了大量的trick,导致很多人看不懂代码,代码下载以后无从下手。训练参数和利用yaml定义网络详细过程可以看我另外的文章,都有写清楚。


其实不管什么网络,训练部分大体都分几个部分:

本文主要会将上述几个部分代码列出来,致于其中的trick部分这里暂不解释【后续我会再写有关内容】,只是为了方便大家先了解训练过程。


注:这里并不是完整的代码,仅仅是对训练代码中几个重要的部分进行梳理、


1.网络的定义


加载权重一般分两种情况:1)有预权重;2)无预权重


1)有预权重情况下的网络定义;

if pretrained:  # 加载预权重
        ckpt = torch.load(weights, map_location=device)  # 加载模型
        # 模型的定义
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors'))else []  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32 获得预权重的权值
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        # 将权重加载到模型内
        model.load_state_dict(state_dict, strict=False)  # load

ckpt是加载的预权重【通过torch.load就能看出来】,weights是预权重路径,device用cpu还是gpu。


state_dict是获取的预权重的权值,通过state_dict()可以看出来。


intersect_dicts就是将预权重和模型本身默认的权值进行对比后再赋值给预权重【比如对比层的shape或keys】


model.load_state_dict就是将预权重加载到网络重。


2)无预权重


无预权重就可以直接使用模型默认权值训练,或者也可以自己去初始化一下。


model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create

2.数据集的处理与加载


数据集的处理的第一步当然是读取数据集路径。这个路径一般在yaml文件中,比如data/coco.yaml.路径文件是txt格式的【后面我会写一篇训练自己的数据集】


获得路径:这个data_dict就是我们读取的yaml文件


train_path = data_dict['train']  # 训练集路径,加载的txt
test_path = data_dict['val']  # 测试集路径,加载的txt

要想定位到数据集的处理,只需要找到两个函数,dataset和Dataloader,因为数据集的处理都需要继承torc提供的这两个函数。


dataset用来处理,dataloader用来加载,这两个函数以及处理数据集在我博客中都有写。


yolov7中是实现了一个create_dataloader函数进行数据集的处理【具体实现可以看我yolov7专栏中的数据集处理】


训练集处理:


# Trainloader  训练数据集的处理
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                            world_size=opt.world_size, workers=opt.workers,
                                            image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))

验证集处理:


testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                       world_size=opt.world_size, workers=opt.workers,
                                       pad=0.5, prefix=colorstr('val: '))[0]

取出图像(后面要放入model中)


pbar = enumerate(dataloader)
pbar = tqdm(pbar, total=nb)  # progress bar
for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0


3.训练超参数的定义与初始化


v7中的超参数,比如学习率,优化器类型等是有个hyp的yaml定义的。


然后在代码中又建立了pg0,pg1,pg2列表用来存储网络中的参数【这些参数就是可导的,也就是我们要的最终网络权重】。


pg0, pg1, pg2 = [], [], []  # optimizer parameter groups,pg0放BN层weight,implicit,pg1放卷积weight,pg2放bias

for k, v in model.named_modules(): # =model.modules()
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay
        if hasattr(v, 'im'):
            if hasattr(v.im, 'implicit'):
                pg0.append(v.im.implicit)
            else:
                for iv in v.im:
                    pg0.append(iv.implicit)
        if hasattr(v, 'imc'):
            if hasattr(v.imc, 'implicit'):
                pg0.append(v.imc.implicit)
            else:
                for iv in v.imc:
                    pg0.append(iv.implicit)
        if hasattr(v, 'imb'):
            if hasattr(v.imb, 'implicit'):
                pg0.append(v.imb.implicit)
            else:
                for iv in v.imb:
                    pg0.append(iv.implicit)
        if hasattr(v, 'imo'):
            if hasattr(v.imo, 'implicit'):
                pg0.append(v.imo.implicit)
            else:
                for iv in v.imo:
                    pg0.append(iv.implicit)
        if hasattr(v, 'ia'):
            if hasattr(v.ia, 'implicit'):
                pg0.append(v.ia.implicit)
            else:
                for iv in v.ia:
                    pg0.append(iv.implicit)
        if hasattr(v, 'attn'):
            if hasattr(v.attn, 'logit_scale'):
                pg0.append(v.attn.logit_scale)
            if hasattr(v.attn, 'q_bias'):
                pg0.append(v.attn.q_bias)
            if hasattr(v.attn, 'v_bias'):
                pg0.append(v.attn.v_bias)
            if hasattr(v.attn, 'relative_position_bias_table'):
                pg0.append(v.attn.relative_position_bias_table)
        if hasattr(v, 'rbr_dense'):
            if hasattr(v.rbr_dense, 'weight_rbr_origin'):
                pg0.append(v.rbr_dense.weight_rbr_origin)
            if hasattr(v.rbr_dense, 'weight_rbr_avg_conv'):
                pg0.append(v.rbr_dense.weight_rbr_avg_conv)
            if hasattr(v.rbr_dense, 'weight_rbr_pfir_conv'):
                pg0.append(v.rbr_dense.weight_rbr_pfir_conv)
            if hasattr(v.rbr_dense, 'weight_rbr_1x1_kxk_idconv1'):
                pg0.append(v.rbr_dense.weight_rbr_1x1_kxk_idconv1)
            if hasattr(v.rbr_dense, 'weight_rbr_1x1_kxk_conv2'):
                pg0.append(v.rbr_dense.weight_rbr_1x1_kxk_conv2)
            if hasattr(v.rbr_dense, 'weight_rbr_gconv_dw'):
                pg0.append(v.rbr_dense.weight_rbr_gconv_dw)
            if hasattr(v.rbr_dense, 'weight_rbr_gconv_pw'):
                pg0.append(v.rbr_dense.weight_rbr_gconv_pw)
            if hasattr(v.rbr_dense, 'vector'):
                pg0.append(v.rbr_dense.vector)


调用优化器是要到 optim这个库:

from torch import optim

这里是可以选用Adam和SGD两种。


# 梯度优化方法Adam和SGD
    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

选择好优化器,我们也获得了网络中要求导的参数,接下来就是把参数放入优化器就行了。


optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
optimizer.add_param_group({'params': pg2})  # add pg2 (biases)

4.损失函数的定义


网络有了,数据集有了,优化器有了,现在就还需要损失函数了。


compute_loss_ota = ComputeLossOTA(model)  # init loss class
compute_loss = ComputeLoss(model)  # init loss class

5.训练


训练阶段,我们需要定义训练的epoch,即利用for训练不停的将image喂入model,将获得的output和真实值之间建立损失函数进行训练。


训练阶段有个很明显的地方就是model.train()【看到这里就知道已经进入了训练阶段】,测试阶段就是model.eval()。


model.train()


for epoch in range(start_epoch, epochs):  # epoch
        model.train()
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        if rank in [-1, 0]:
        pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

pbar就是从dataloader中读取数据,并将imgs放入device后除以255.0【像素的归一化】


5.1.1前向传播

pred就是得到的output,然后与ground truth签求loss。


   

# Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)  # loss scaled by batch_size

5.1.2反向传播

看到backward()这个地方就是梯度反向传播的过程

# Backward
            scaler.scale(loss).backward()

5.1.3梯度更新

看到update()和zero_grad()就是梯度更新的地方,每次更新后在进行清零操作。

# Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()

5.1.4模型保存

模型的保存会用到torch.save(),所以只要看到这个地方就知道是模型训练好的模型。保存的ckpt就是我们的训练模型,只不过里面还保存了训练的epoch,best_fitness等。

                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': results_file.read_text(),
                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
                        'ema': deepcopy(ema.ema).half(),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': None}
                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if (best_fitness == fi) and (epoch >= 200):
                    torch.save(ckpt, wdir / 'best_{:03d}.pt'.format(epoch))
                if epoch == 0:
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                elif ((epoch+1) % 25) == 0:
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                elif epoch >= (epochs-5):
                    torch.save(ckpt, wdir / 'epoch_{:03d}.pt'.format(epoch))
                del ckpt


目录
相关文章
|
6月前
|
Python
【论文复现】针对yoloV5-L部分的YoloBody部分重构(Slim-neck by GSConv)
【论文复现】针对yoloV5-L部分的YoloBody部分重构(Slim-neck by GSConv)
173 0
【论文复现】针对yoloV5-L部分的YoloBody部分重构(Slim-neck by GSConv)
|
12天前
|
机器学习/深度学习 人工智能 计算机视觉
YOLOv11 正式发布!你需要知道什么? 另附:YOLOv8 与YOLOv11 各模型性能比较
YOLOv11是Ultralytics团队推出的最新版本,相比YOLOv10带来了多项改进。主要特点包括:模型架构优化、GPU训练加速、速度提升、参数减少以及更强的适应性和更多任务支持。YOLOv11支持目标检测、图像分割、姿态估计、旋转边界框和图像分类等多种任务,并提供不同尺寸的模型版本,以满足不同应用场景的需求。
YOLOv11 正式发布!你需要知道什么? 另附:YOLOv8 与YOLOv11 各模型性能比较
|
1月前
|
PyTorch 算法框架/工具 计算机视觉
目标检测实战(二):YoloV4-Tiny训练、测试、评估完整步骤
本文介绍了使用YOLOv4-Tiny进行目标检测的完整流程,包括模型介绍、代码下载、数据集处理、网络训练、预测和评估。
92 2
目标检测实战(二):YoloV4-Tiny训练、测试、评估完整步骤
|
1月前
|
计算机视觉
目标检测笔记(二):测试YOLOv5各模块的推理速度
这篇文章是关于如何测试YOLOv5中不同模块(如SPP和SPPF)的推理速度,并通过代码示例展示了如何进行性能分析。
78 3
|
1月前
|
计算机视觉 异构计算
目标检测实战(四):YOLOV4-Tiny 源码训练、测试、验证详细步骤
这篇文章详细介绍了使用YOLOv4-Tiny进行目标检测的实战步骤,包括下载源码和权重文件、配置编译环境、进行简单测试、训练VOC数据集、生成训练文件、准备训练、开始训练以及多GPU训练的步骤。文章还提供了相应的代码示例,帮助读者理解和实践YOLOv4-Tiny模型的训练和测试过程。
92 0
|
机器学习/深度学习 PyTorch 算法框架/工具
ResNet代码复现+超详细注释(PyTorch)
ResNet代码复现+超详细注释(PyTorch)
2128 1
|
6月前
|
并行计算 计算机视觉
YOLOv8太卷啦 | YOLOv8官方仓库正式支持RT-DETR训练、测试以及推理
YOLOv8太卷啦 | YOLOv8官方仓库正式支持RT-DETR训练、测试以及推理
467 0
|
PyTorch 算法框架/工具 机器学习/深度学习
GoogLeNet InceptionV3代码复现+超详细注释(PyTorch)
GoogLeNet InceptionV3代码复现+超详细注释(PyTorch)
391 0
|
PyTorch 算法框架/工具
GoogLeNet InceptionV1代码复现+超详细注释(PyTorch)
GoogLeNet InceptionV1代码复现+超详细注释(PyTorch)
332 0
|
计算机视觉
【YOLOV5-6.x讲解】YOLO5.0VS6.0版本对比+模型设计
【YOLOV5-6.x讲解】YOLO5.0VS6.0版本对比+模型设计
1071 0
【YOLOV5-6.x讲解】YOLO5.0VS6.0版本对比+模型设计