MMdetection框架速成系列 第02部分:整体算法流程+模型搭建流程+detection训练与测试核心组件+训练部分与测试部分的核心算法

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,5000CU*H 3个月
模型训练 PAI-DLC,5000CU*H 3个月
简介: 众所周知,目标检测算法比较复杂,细节比较多,难以复现,而我们推出的 MMDetection 开源框架则希望解决上述问题。目前 MMdetection 已经复现了大部分主流和前沿模型,例如 Faster R-CNN 系列、Mask R-CNN 系列、YOLO 系列和比较新的 DETR 等等,模型库非常丰富,star 接近 13k,在学术研究和工业落地中应用非常广泛。

🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗


MMdetection框架速成系列


MMdetection框架速成系列 第01部分:

https://v9999.blog.csdn.net/article/details/128486362

MMdetection框架速成系列 第02部分:

https://v9999.blog.csdn.net/article/details/128486548

MMdetection框架速成系列 第03部分:

https://v9999.blog.csdn.net/article/details/129294753

🚗🚗🚗🚗🚗🚗🚗正文开始🚗🚗🚗🚗🚗🚗🚗


本系列解读主要分享 MMDetection 中已经复现的主流目标检测模型。


众所周知,目标检测算法比较复杂,细节比较多,难以复现,而我们推出的 MMDetection 开源框架则希望解决上述问题。目前 MMdetection 已经复现了大部分主流和前沿模型,例如 Faster R-CNN 系列、Mask R-CNN 系列、YOLO 系列和比较新的 DETR 等等,模型库非常丰富,star 接近 13k,在学术研究和工业落地中应用非常广泛。


任何一个目标检测算法都可以分成 n 个核心组件,组件和组件之间是隔离的,方便复用和设计。


当面对一个新算法时候我们可以先分析其主要是改进哪几个核心组件,然后就可以高效的掌握该算法。


另外还有一些重要的模块没有分析,特别是 dataset、dataloader 和分布式训练相关的检测代码。


71d3528f4c524b2d80cc75ccc8d68130.png


1 MMDetection是什么


MMDetection是OpenMMLab家族中的一员,主要负责2D目标检测领域(比如MMDetection3D则负责3D目标检测)。首先我们需要知道为什么会出现MMDetection这个框架。当前目标检测算法众多,方法复杂,细节较多,个人复现起来难度很大,而且由于缺少共享平台和统一规范,就算有人成功实现了某一个算法,也很难被其他人复用。


该模型使用统一的代码规范复现了当前大部分主流和前沿的模型,比如Faster R-CNN系列、YOLO系列,以及较新的DETR等(如下图所示),并提供了预训练模型。其他人只需要遵循这个规范,就能直接“白嫖”,不需要自己再重新实现一遍,而这个规范就MMDetection。


在丰富模型的基础上,MMDetection还支持自定义的扩展,可以在已有模型上进行修改,也可以自己从头搭建一个全新的模型,基本可以满足学术研究和工业落地的需求。


33f7abce0c3c4fa3a27eccc5ee031c2d.png


1.1 模型分类


1509e6e44900425db2a3225147d7293c.png


注意上面仅仅写了几个典型算法而已,简单来说目标检测算法可以按照 3 个维度划分:


  • 按照 stage 个数划分,常规是 one-stage 和 two-stage,但是实际上界限不是特别清晰,例如带 refine 阶段的算法 RepPoints,实际上可以认为是1.5 stage 算法,而 Cascade R-CNN 可以认为是多阶段算法,为了简单,上面图示没有划分如此细致


  • 按照是否需要预定义 anchor 划分,常规是 anchor-based 和 anchor-free,当然也有些算法是两者混合的


  • 按照是否采用了 transformer 结构划分,目前基于 transformer 结构的目标检测算法发展迅速,也引起了极大的关注,所以这里特意增加了这个类别的划分


不管哪种划分方式,其实都可以分成若干固定模块,然后通过模块堆叠来构建整个检测算法体系。


2. 整体算法流程


所有的目标检测算法都可以按照训练和测试流程抽象成若干个模块,对于初学者来说只要理解各个模块的输入输出以及实现的功能即可。这个流程也对应框架的代码构建流程,所以理解这副图很重要。


34132e058ec74dd29d85884516f5b53c.png


上述流程对应 MMDetection 代码构建流程,理解每个组件的作用不仅仅对阅读算法源码有帮助,而且还能够快速理解新提出算法对应的改进部分。下面对每个模块进行详细解读。


3 detection训练核心组件


训练部分一般包括 9 个核心组件,总体流程是:


①任何一个 batch 的图片先输入到 backbone 中进行特征提取,典型的骨干网络是 ResNet


②输出的单尺度或者多尺度特征图输入到 neck 模块中进行特征融合或者增强,典型的 neck 是 FPN


③上述多尺度特征最终输入到 head 部分,一般都会包括分类和回归分支输出


④在整个网络构建阶段都可以引入一些即插即用增强算子来增加提取提取能力,典型的例如 SPP、DCN 等等


⑤目标检测 head 输出一般是特征图,对于分类任务存在严重的正负样本不平衡,可以通过正负样本属性分配和采样控制


⑥为了方便收敛和平衡多分支,一般都会对 GT-bbox 进行编码


⑦最后一步是计算分类和回归 loss,进行训练


⑧在训练过程中也包括非常多的 trick,例如优化器选择等,参数调节也非常关键


624167aead404d97a56782798e1fb2f1.png


注意上述 9 个组件不是每个算法都需要的,下面详细分析。


3.1 Backbone


b91444496e5c41dcabd09bd4de34683a.png


backbone 作用主要是特征提取。目前 MMDetection 中已经集成了大部分骨架网络,具体见文件:

mmdet/models/backbones,V2.7 已经实现的骨架如下:


__all__ = [
    'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
    'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
    'ResNeSt', 'TridentResNet'
]


最常用的是 ResNet 系列、ResNetV1d 系列和 Res2Net 系列。若需要对backbone 进行扩展,继承上述网络的情况下,通过注册器机制注册使用。


# 典型用法为
# 骨架的预训练权重路径
pretrained='torchvision://resnet50',
backbone=dict(
    type='ResNet', # 骨架类名,后面的参数都是该类的初始化参数
    depth=50,
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=True), 
    norm_eval=True,
    style='pytorch'),


在mmdetection中,可以通过 MMCV 中的注册器机制,可以通过 dict 形式的配置来实例化任何已经注册的类,非常方便和灵活。


3.2 Neck


feebd465b59d47ffa56f73b29f2d86ba.png


neck 部位: backbone 和 head 的连接层,对 backbone 的特征进行高效融合和增强,能够对输入的单尺度或者多尺度特征进行融合、增强输出等。具体见文件mmdet/models/necks,其V2.7 已经实现的 neck 如下:


__all__ = [
    'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
    'NASFCOS_FPN', 'RFP', 'YOLOV3Neck'
]


最常用的应该是 FPN,一个典型用法是:


neck=dict(
    type='FPN',
    in_channels=[256, 512, 1024, 2048], # 骨架多尺度特征图输出通道
    out_channels=256, # 增强后通道输出
    num_outs=5), # 输出num_outs个多尺度特征图


3.3 Head


a2a2f855baa84ec3acd0d34f632f25e7.png


目标检测算法输出一般包括分类和框坐标回归两个分支,不同算法 head 模块复杂程度不一样,灵活度比较高。在网络构建方面,理解目标检测算法主要是要理解 head 模块。


MMDetection 中 head 模块又划分为 two-stage 所需的 RoIHead 和 one-stage 所需的 DenseHead,也就是说所有的 one-stage 算法的 head 模块都在mmdet/models/dense_heads中,而 two-stage 算法还包括额外的mmdet/models/roi_heads。


目前 V2.7 中已经实现的 dense_heads 包括:


__all__ = [
    'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
    'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead',
    'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
    'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
    'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
    'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
    'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'TransformerHead'
]


几乎每个算法都包括一个独立的 head,而 roi_heads 比较杂,就不列出了。


需要注意的是:two-stage 或者 mutli-stage 算法,会额外包括一个区域提取器 roi extractor,用于将不同大小的 RoI 特征图统一成相同大小。


虽然 head 部分的网络构建比较简单,但是由于正负样本属性定义、正负样本采样和 bbox 编解码模块都在 head 模块中进行组合调用,故 MMDetection 中最复杂的模块就是 head。在最后的整体流程部分会对该模块进行详细分析。


3.4 Enhance


d95f702553eb442397cccc2ac0150c5c.png


enhance 是即插即用、能够对特征进行增强的模块,其具体代码可以通过 dict 形式注册到 backbone、neck 和 head 中,非常方便(目前还不完善)。


常用的 enhance 模块是 SPP、ASPP、RFB、Dropout、Dropblock、DCN 和各种注意力模块 SeNet、Non_Local、CBA 等。


目前 MMDetection 中部分模块支持 enhance 的接入,例如 ResNet 骨架中的 plugins,这个部分的解读放在具体算法模块中讲解。


3.5 BBox Assigner


正负样本属性分配模块作用是进行正负样本定义或者正负样本分配(可能也包括忽略样本定义),正样本就是常说的前景样本(可以是任何类别),负样本就是背景样本。


因为目标检测是一个同时进行分类和回归的问题,对于分类场景必然需要确定正负样本,否则无法训练。该模块至关重要,不同的正负样本分配策略会带来显著的性能差异,目前大部分目标检测算法都会对这个部分进行改进,至关重要。


典型的分配策略如下:


365ccea3b70246a58aff0aa035a64266.png


对应的代码在mmdet/core/bbox/assigners中,V2.7 主要包括:


__all__ = [
    'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 
    'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
    'HungarianAssigner'
]


3.9 BBox Sampler


在确定每个样本的正负属性后,可能还需要进行样本平衡操作。


本模块作用是对前面定义的正负样本不平衡进行采样,力争克服该问题。一般在目标检测中 gt bbox 都是非常少的,所以正负样本比是远远小于 1 的。


而基于机器学习观点:在数据极度不平衡情况下进行分类会出现预测倾向于样本多的类别,出现过拟合,为了克服该问题,适当的正负样本采样策略是非常必要的,一些典型采样策略如下:


9313e467e30c492b939a7bdf904d3403.png


__all__ = [
    'BaseSampler', 'PseudoSampler', 'RandomSampler',
    'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
    'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler'
]


3.10 BBox Encoder


为了更好的收敛和平衡多个 loss,具体解决办法非常多,而 bbox 编解码策略也算其中一个,bbox 编码阶段对应的是对正样本的 gt bbox 采用某种编码变换(反操作就是 bbox 解码),最简单的编码是对 gt bbox 除以图片宽高进行归一化以平衡分类和回归分支,一些典型的编解码策略如下:


dbc333f22a0c4898a7e756b6642a9465.png


对应的代码在mmdet/core/bbox/coder中,V2.7 主要包括:


__all__ = [
    'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
    'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
    'BucketingBBoxCoder'
]


3.11 Loss


Loss 通常都分为分类和回归 loss,其对网络 head 输出的预测值和 bbox encoder 得到的 targets 进行梯度下降迭代训练。


Loss 的设计也是各大算法重点改进对象,常用的 loss 如下:


d640556d27dd4953809f226ef4222ec7.png


对应的代码在mmdet/models/losses中,V2.7 主要包括:


__all__ = [
    'cross_entropy', 'binary_cross_entropy',
    'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
    'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
    'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
    'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss', 'GHMC',
    'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
    'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
    'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
    'VarifocalLoss'
]


可以看出 MMDetection 中已经实现了非常多的 loss,可以直接使用。


3.12 Training tricks


训练技巧非常多,常说的调参很大一部分工作都是在设置这部分超参。这部分内容比较杂乱,很难做到完全统一,目前主流的 tricks 如下所示:


ea077d0a3b9544c18f978d9571793092.png


MMDetection 目前这部分还会继续完善,也欢迎大家一起贡献。


4 detection测试核心组件


测试核心组件和训练非常类似,但是简单很多,除了必备的网络构建部分外( backbone、neck、head 和 enhance ),不需要正负样本定义、正负样本采样和 loss 计算三个最难的部分,但是其额外需要一个 bbox 后处理模块和测试 trick。


相较于训练流程,测试时只有模型的前向推理过程,因此不需要正负样本分配、平衡、计算loss等操作,流程会更简单一些。


下表是测试流程特有的模块:


f06b5b2e9aa1462c9500b289d3591721.png


4.1 BBox Decoder


训练时候进行了编码,那么对应的测试环节需要进行解码。根据编码的不同,解码也是不同的。举个简单例子:假设训练时候对宽高是直接除以图片宽高进行归一化的,那么解码过程也仅仅需要乘以图片宽高即可。


其代码和 bbox encoder 放在一起,在mmdet/core/bbox/coder中。


4.2 BBox PostProcess


在得到原图尺度 bbox 后,由于可能会出现重叠 bbox 现象,故一般都需要进行后处理,最常用的后处理就是非极大值抑制以及其变种。


其对应的文件在mmdet/core/post_processing中,V2.7 主要包括:


__all__ = [
    'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
    'merge_aug_scores', 'merge_aug_masks', 'fast_nms'
]


4.3 Testing tricks


为了提高检测性能,测试阶段也会采用 trick。这个阶段的 tricks 也非常多,难以完全统一,最典型的是多尺度测试以及各种模型集成手段,典型配置如下:


dict(
    type='MultiScaleFlipAug',
    img_scale=(1333, 800),
    flip=True,
    transforms=[
        dict(type='Resize', keep_ratio=True),
        dict(type='RandomFlip'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='ImageToTensor', keys=['img']),
        dict(type='Collect', keys=['img']),
    ])

f341614e538e446fbd1752380a7879dd.png



4.4 训练测试算法流程


在分析完每个训练流程的各个核心组件后,为了方便大家理解整个算法构建,下面分析 MMDetection 是如何组合各个组件进行训练的,这里以 one-stage 检测器为例,two-stage 也比较类似。


class SingleStageDetector(---):
   def __init__(...):
        # 构建骨架、neck和head
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        self.bbox_head = build_head(bbox_head)
  def forward_train(---): 
        # 先运行backbone+neck进行特征提取
        x = self.extract_feat(img)
        # 对head进行forward train,输出loss
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)
        return losses
  def simple_test(---):
        # 先运行backbone+neck进行特征提取
        x = self.extract_feat(img)
        # head输出预测特征图
        outs = self.bbox_head(x)
        # bbox解码和还原
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
        # 重组结果返回
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results


以上就是整个检测器算法训练和测试最简逻辑,可以发现训练部分最核心的就是bbox_head.forward_train,测试部分最核心的是bbox_head.get_bboxes,下面单独简要分析。


5 训练部分与测试部分的两个核心算法


5.1 训练部分 bbox_head.forward_train


def forward_train(...):
    # 调用每个head自身的forward方法
    outs = self(x)
    if gt_labels is None:
        loss_inputs = outs + (gt_bboxes, img_metas)
    else:
        loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
    # 计算每个head自身的loss方法
    losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
    # 返回
    return losses


对于不同的 head,虽然 forward 内容不一样,但是依然可以抽象为: outs = self(x)


def forward(self, feats):
   # 多尺度特征图,一个一个迭代进行forward_single
   return multi_apply(self.forward_single, feats)
def forward_single(self, x):
   # 运行各个head独特的head forward方法,得到预测图
   ....
   return cls_score, bbox_pred...


而对于不同的 head,其 loss 计算部分也比较复杂,可以简单抽象为:losses = self.loss(...)


def loss(...):
    # 1 生成anchor-base需要的anchor或者anchor-free需要的points
    # 2 利用gt bbox对特征图或者anchor计算其正负和忽略样本属性
    # 3 进行正负样本采样
    # 4 对gt bbox进行bbox编码
    # 5 loss计算,并返回
    return dict(loss_cls=losses_cls, loss_bbox=losses_bbox,...)


5.2 测试部分 bbox_head.get_bboxes


get_bboxes函数更加简单


def get_bboxes(...):
   # 1 生成anchor-base需要的anchor或者anchor-free需要的points
   # 2 遍历每个输出层,遍历batch内部的每张图片,对每张图片先提取指定个数的预测结果,缓解后面后处理压力;对保留的位置进行bbox解码和还原到原图尺度
   # 3 统一nms后处理
   return det_bboxes, det_labels...


6. 算法搭建流程


以Pytorch训练流程为例,需要编写数据读取、数据预处理、数据增强、算法模型、loss函数、训练策略的代码,最后将其整合进train()函数中开始训练。


MMDetection已经实现上述步骤中的绝大部分方法,因此只需要调用现成的函数即可,具体是在Config文件中配置好相应方法的参数,并将Config文件传给MMDetection自带的train()函数,然后框架就会解析Config文件,自动调用配置好的方法,完成训练流程。


所以在MMDetection上搭建一个算法,要做的事情只有3件:准备数据集、编写Config文件、调用框架自带的train.py开始训练。


0f530cc947c54e0786a49de4d009df86.png


先从MMDetection自带的RetinaNet开始,在COCO数据集上打通训练和测试流程。


6.1 准备数据集


MMDetection已经实现COCO数据集的处理,这里就直接使用COCO 2014数据集。下载好的数据集目录结构如下图所示,annotations文件夹中以json文件格式存放了标注数据,其中目标框的标注信息在instances文件中:


3da5095e3c52498daa1fa63077e88a4b.png


3.2 编写Config文件


RetinaNet的配置文件位于MMDetection源码的./configs/retinanet路径下,打开目录会发现里面有很多Config文件,文件命名规则遵循:


{model}_[model setting]_{backbone}_{neck}_[norm setting]_[misc]_[gpu x batch_per_gpu]_{schedule}_{dataset}


其中大括号表示必选,大括号表示可选。比如我们接下来要使用的配置文件retinanet_r50_fpn_1x_coco.py含义就是:模型名称是RetinaNet,主干是ResNet50,Neck是FPN,训练12个Epoch(1个x是12,2个x就是24),使用COCO数据集。


716f3f5c7f8548cea979717cd4b0e462.png


但当打开配置文件retinanet_r50_fpn_1x_coco.py时,发现里面只有几行代码:


# 原因是MMDetection中配置文件是通过继承 + 修改的方式完成用户自定义配置文件的。
# _ base_ = list()表示需要继承的配置文件,然后通过重写的方式完成对应属性的修改。
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)


如果想要查看完整的配置文件信息,不需要依据_base_一级一级往上查找,可以通过官方给定的工具./tools/misc/print_config.py来打印配置文件:


# 在 ./tools/misc/print_config.py 
python ./tools/misc/print_config.py ./configs/retinanet/retinanet_r50_fpn_1x_coco.py


然后就可以看到retinanet_r50_fpn_1x_coco.py对应的完整的配置文件内容,第二节提到的各个模块都可以在配置文件中找到对应的定义。


配置文件由一串字典dict和变量的定义组成,经由Config.fromfile(filepath)函数加载后会返回一个Config类型的变量(MMCV的一个数据结构),然后MMDetection框架就能根据这个Config调用相关的build_detector()方法构建对应的模块。


build_detector()方法:首先会根据字典中的type找到对应的类(Class),这个类的类名就是type字符串的值,且这个类一定是事先注册(Registry) 好的,MMDetection能够根据type值查询到具体的类,否则就会报错。比如在下面配置文件中,model的type值为RetinaNet,我们可以在./mmdet/models/detectors/retinanet.py中找到定义。


@DETECTORS.register_module()    # 表示这个类已经注册
class RetinaNet(SingleStageDetector):
    """Implementation of `RetinaNet <https://arxiv.org/abs/1708.02002>`_"""
    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(RetinaNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained, init_cfg)


RetinaNet类构造函数的参数刚好和配置文件中type='RetinaNet'的字典的其他键值对应。所以build_detector()函数的作用就是根据dict中的type找到对应的类,然后使用dict中传入的参数来对类进行初始化操作,并返回这个类的句柄。


# 下面两行调用是等价的
model = build_detector(Config{type='RetinaNet', backbone=xxx, neck=xxx, bbox_head=xxx})
model = RetinaNet(backbone=xxx, neck=xxx, bbox_head=xxx)


然后配置文件中的dict是可以嵌套的,比如说model的backbone属性是type='ResNet'一个字典,同理我们也可以在./mmdet/models/backbones/resnet.py中找到ResNet类的定义,并且字典的键值和构造函数匹配:


@BACKBONES.register_module()
class ResNet(BaseModule):
    """ResNet backbone."""
    def __init__(self,
                 depth,
                 in_channels=3,
                 stem_channels=None,
                 base_channels=64,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 plugins=None,
                 with_cp=False,
                 zero_init_residual=True,
                 pretrained=None,
                 init_cfg=None):
        super(ResNet, self).__init__(init_cfg)
        self.zero_init_residual = zero_init_residual
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')


下面是retinanet_r50_fpn_1x_coco.py完整的配置文件信息


Config:
# 1. 模型配置
model = dict(
    type='RetinaNet',   # 模型名称
    # 1.1 Backbone配置
    backbone=dict(
        type='ResNet',    # Backbone使用ResNet50(4阶段,50层)
        depth=50,     
        num_stages=4,
        out_indices=(0, 1, 2, 3), # 输出ResNet50第1~4阶段的feature map,供后续FPN做多尺度特征融合
        frozen_stages=1,      # 由于使用了预训练模型,冻结ResNet50第一阶段的网络参数,不参与训练过程
        norm_cfg=dict(type='BN', requires_grad=True), # 归一化层配置
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),   # 使用pytorch提供的ResNet50在ImageNet上的预训练模型
    # 1.2 Neck配置
    neck=dict(
        type='FPN',   # Neck使用FPN
        in_channels=[256, 512, 1024, 2048],   # 输入通道数对应resnet50四个阶段feature map的维度
        out_channels=256,           # 输出特征维度为256
        start_level=1,              # 从Backbone的第一阶段特征图开始
        add_extra_convs='on_input',
        num_outs=5),
    # 1.3 Head配置
    bbox_head=dict(
        type='RetinaHead',  # Head使用RetinaHead
        num_classes=80,   # COCO数据集包含80类目标
        in_channels=256,  # FPN层输出特征维度为256
        stacked_convs=4,
        feat_channels=256,
        # 1.3.1 Retina是Anchor-Based方法, 需要生成Anchor
        anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        # 1.3.2 BBox Encoder配置
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        # 1.3.3 分类Loss函数
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        # 1.3.4 回归Loss函数
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    # 1.4 训练配置
    train_cfg=dict(
      # 1.4.1 BBox Assigner
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.4,
            min_pos_iou=0,
            ignore_iof_thr=-1),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    # 1.5 测试配置
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100))
# 2. 数据配置
data = dict(
    samples_per_gpu=2,    # batch_size大小
    workers_per_gpu=2,    # 每个GPU的线程数, 影响dataload的速度
    # 2.1 训练集配置
    train=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_train2017.json',      
        img_prefix='data/coco/train2017/',                  
        # 数据预处理步骤
        pipeline=[
            dict(type='LoadImageFromFile'),                 
            dict(type='LoadAnnotations', with_bbox=True),         
            dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),  
            dict(type='RandomFlip', flip_ratio=0.5),            
            dict(                             
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),                
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ]),
    # 2.2 验证集配置
    val=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    # 2.3 测试集配置
    test=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]))
# evalution hook配置
evaluation = dict(interval=1, metric='bbox')
# 优化器配置
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
# optimizer hook配置
optimizer_config = dict(grad_clip=None)
# 学习率配置
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
# Runner配置
runner = dict(type='EpochBasedRunner', max_epochs=12)
# checkpoint配置
checkpoint_config = dict(interval=1)
# logger hook配置
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
# 自定义hook配置
custom_hooks = [dict(type='NumClassCheckHook')]
# 分布式训练配置
dist_params = dict(backend='nccl')
# 日志级别
log_level = 'INFO'
# 预训练模型路径
load_from = None
# 模型断点
resume_from = None
# Runner的工作流
workflow = [('train', 1)]


从配置文件可以看到,当前默认从pytorch官网下载预训练模型,且数据集的路径以及GPU数目和我当前的不符,而且由于电脑内存有限,我不希望每个epoch都保存一次checkpoint,所以我新建了一个配置文件my_retinanet_r50_fpn.py继承了官方的配置文件,并进行了一些修改:


_base_ = [
    'D:/Program Files/OpenSourceLib/mmdetection/configs/retinanet/retinanet_r50_fpn_1x_coco.py'
]
model = dict(
    backbone=dict(
        init_cfg=None)    # 不再直接从官网下载预训练模型,使用我自己下载好的预训练模型
)
data = dict(
    samples_per_gpu=2,    # batch_size=2
    workers_per_gpu=1,    # 每个GPU的线程数, 影响dataload的速度
    train=dict(
        type='CocoDataset',
        ann_file='E:/Dataset/COCO2014/annotations/instances_train2014.json',  # 修改数据集路径
        img_prefix='E:/Dataset/COCO2014/train2014'),
    val=dict(
        type='CocoDataset',
        ann_file='E:/Dataset/COCO2014/annotations/instances_val2014.json',
        img_prefix='E:/Dataset/COCO2014/val2014/'),
    test=dict(
        type='CocoDataset',
        ann_file='E:/Dataset/COCO2014/annotations/instances_val2014.json',
        img_prefix='E:/Dataset/COCO2014/val2014/')
)
evaluation = dict(interval=12, metric='bbox') # 12个epoch进行一次评估
checkpoint_config = dict(interval=2)  # 2个epoch保存一次checkpoint
load_from = '../ckpts/resnet50-0676ba61.pth'  # 自己下载的预训练模型路径


6.3 训练网络


写完配置文件后,就可以直接调用./tools/train.py指定配置文件进行训练。train.py包含了模型配置、数据集配置、训练配置、Hook配置等的解析,以及根据配置信息构造训练,用户的自定义操作可以通过Hook进行配置,一般无需修改train.py文件。


python train.py my_retinanet_r50_fpn.py


成功开始训练:


62a13f03ed5d43f2b3fdf0c9967940c2.png


7.总结


本文利用MMDetection已经实现的RetinaNet模型在COCO上进行训练作为示例,演示了MMDetection的模型训练流程。总的来说分为三个步骤:


1、准备数据集


2、准备配置文件:配置文件由一系列dict组成,dict中的type键值代表注册的类别,build函数可以通过识别dict中的type来初始化对应的类。配置文件一般会继承一个通用配置文件,然后在此基础上根据需求调整。


3、开始训练:调用MMDetection自带的train.py进行训练。


如果需要构建自己的模型,则需要实现一个类然后进行注册,Registry和Hook的机制。

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
目录
相关文章
|
1天前
|
测试技术 开发工具
自动化测试框架的演进与选择指南
随着软件开发行业的蓬勃发展,自动化测试已成为确保软件质量的重要手段。本文将探讨自动化测试框架的发展历程,分析不同类型框架的特点,并提供选择合适框架的建议,旨在帮助读者根据项目需求做出明智的决策。
|
1天前
|
资源调度 数据管理 测试技术
自动化测试框架的演进与实践
软件测试作为保障产品质量的重要环节,在技术迭代的浪潮中不断进化。本文从自动化测试框架的角度出发,探讨了其发展历程、当前面临的挑战以及未来的趋势。我们将深入分析自动化测试框架的设计原则、核心组件和最佳实践,旨在为读者提供一套系统化的解决方案,帮助团队提升测试效率,确保软件质量。
|
2天前
|
分布式计算 算法 Java
阿里云ODPS PySpark任务使用mmlspark/synapseml运行LightGBM进行Boosting算法的高效训练与推理
阿里云ODPS PySpark任务使用mmlspark/synapseml运行LightGBM进行Boosting算法的高效训练与推理
|
3天前
|
XML Java 测试技术
《手把手教你》系列基础篇(八十二)-java+ selenium自动化测试-框架设计基础-TestNG测试报告-上篇(详解教程)
【6月更文挑战第23天】TestNG 是一个用于自动化测试的 Java 框架,它自动生成测试报告,包括 HTML 和 XML 格式。报告可在 `test-output` 文件夹中找到。要创建测试用例,可创建一个实现了 `@Test` 注解的方法。通过 `testng.xml` 配置文件来组织和执行测试。默认报告包括测试结果、失败点和原因。用户还能实现 `ITestListener` 和 `IReporter` 接口来自定义报告和记录器。
20 2
|
3天前
|
机器学习/深度学习 人工智能 测试技术
自动化测试框架的演进与实践
【6月更文挑战第23天】在软件工程领域,自动化测试框架的发展不断推动着质量保证的效率和效果。本文将探讨自动化测试框架从简单脚本到复杂集成系统的演变过程,并分析当前流行的框架如Selenium、Appium以及新兴的AI驱动测试工具。我们将通过具体案例,展示如何在现代软件开发实践中有效应用这些框架以提升测试覆盖率和准确性。
|
4天前
|
Java 测试技术 Python
《手把手教你》系列基础篇(八十一)-java+ selenium自动化测试-框架设计基础-TestNG如何暂停执行一些case(详解教程)
【6月更文挑战第22天】本文介绍了如何在TestNG中不执行特定测试用例。当部分模块未准备好时,可以通过以下方式暂停测试:③使用`@Test(enabled=false)`注解来禁用测试用例。作者提供了一个Java Selenium自动化测试的示例,展示如何通过修改`enabled`参数控制测试方法的执行。代码中,`testSearch2()`方法被禁用,因此在测试运行时不执行。文章还包含了测试报告和执行过程的截图。
36 7
|
2天前
|
机器学习/深度学习 自然语言处理 算法
m基于深度学习的OFDM+QPSK链路信道估计和均衡算法误码率matlab仿真,对比LS,MMSE及LMMSE传统算法
**摘要:** 升级版MATLAB仿真对比了深度学习与LS、MMSE、LMMSE的OFDM信道估计算法,新增自动样本生成、复杂度分析及抗频偏性能评估。深度学习在无线通信中,尤其在OFDM的信道估计问题上展现潜力,解决了传统方法的局限。程序涉及信道估计器设计,深度学习模型通过学习导频信息估计信道响应,适应频域变化。核心代码展示了信号处理流程,包括编码、调制、信道模拟、降噪、信道估计和解调。
23 8
|
3天前
|
算法 计算机视觉
基于Chan-Vese算法的图像边缘提取matlab仿真
**算法预览展示了4幅图像,从边缘检测到最终分割,体现了在matlab2022a中应用的Chan-Vese水平集迭代过程。核心代码段用于更新水平集并显示迭代效果,最后生成分割结果及误差曲线。Chan-Vese模型(2001)是图像分割的经典方法,通过最小化能量函数自动检测平滑区域和清晰边界的图像分割,适用于复杂环境,广泛应用于医学影像和机器视觉。**
|
4天前
|
算法
基于GA遗传优化的混合发电系统优化配置算法matlab仿真
**摘要:** 该研究利用遗传算法(GA)对混合发电系统进行优化配置,旨在最小化风能、太阳能及电池储能的成本并提升系统性能。MATLAB 2022a用于实现这一算法。仿真结果展示了一系列图表,包括总成本随代数变化、最佳适应度随代数变化,以及不同数据的分布情况,如负荷、风速、太阳辐射、弃电、缺电和电池状态等。此外,代码示例展示了如何运用GA求解,并绘制了发电单元的功率输出和年变化。该系统原理基于GA的自然选择和遗传原理,通过染色体编码、初始种群生成、适应度函数、选择、交叉和变异操作来寻找最优容量配置,以平衡成本、效率和可靠性。
|
5天前
|
机器学习/深度学习 算法
基于鲸鱼优化的knn分类特征选择算法matlab仿真
**基于WOA的KNN特征选择算法摘要** 该研究提出了一种融合鲸鱼优化算法(WOA)与K近邻(KNN)分类器的特征选择方法,旨在提升KNN的分类精度。在MATLAB2022a中实现,WOA负责优化特征子集,通过模拟鲸鱼捕食行为的螺旋式和包围策略搜索最佳特征。KNN则用于评估特征子集的性能。算法流程包括WOA参数初始化、特征二进制编码、适应度函数定义(以分类准确率为基准)、WOA迭代搜索及最优解输出。该方法有效地结合了启发式搜索与机器学习,优化特征选择,提高分类性能。