MMdetection框架速成系列 第03部分:简述整体构建细节与模块+训练测试模块流程剖析+深入解析代码模块与核心实现

本文涉及的产品
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
简介: 按照抽象到具体方式,从多个层次进行训练和测试流程深入解析,从最抽象层讲起,到最后核心代码实现,希望帮助大家更容易理解 MMDetection 开源框架整体构建细节

🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗🚗


MMdetection框架速成系列


MMdetection框架速成系列 第01部分:

https://v9999.blog.csdn.net/article/details/128486362

MMdetection框架速成系列 第02部分:

https://v9999.blog.csdn.net/article/details/128486548

MMdetection框架速成系列 第03部分:

https://v9999.blog.csdn.net/article/details/129294753

🚗🚗🚗🚗🚗🚗🚗正文开始🚗🚗🚗🚗🚗🚗🚗


按照抽象到具体方式,从多个层次进行训练和测试流程深入解析,从最抽象层讲起,到最后核心代码实现,希望帮助大家更容易理解 MMDetection 开源框架整体构建细节


1 第一层整体抽象


下图为 MMDetection 框架整体训练和测试抽象流程图。


3c093f6f082e43dba3029aedc8676de9.png


按照数据流过程,训练流程可以简单总结为:


1.给定任何一个数据集,首先需要构建 Dataset 类,用于迭代输出数据


2.在迭代输出数据的时候需要通过数据 Pipeline 对数据进行各种处理,最典型的处理流是训练中的数据增强操作,测试中的数据预处理等等


3.通过 Sampler 采样器可以控制 Dataset 输出的数据顺序,最常用的是随机采样器 RandomSampler。由于 Dataset 中输出的图片大小不一样,为了尽可能减少后续组成 batch 时 pad 的像素个数,MMDetection 引入了分组采样器 GroupSampler 和 DistributedGroupSampler,相当于在 RandomSampler 基础上额外新增了根据图片宽高比进行 group 功能


4.将 Sampler 和 Dataset 都输入给 DataLoader,然后通过 DataLoader 输出已组成 batch 的数据,作为 Model 的输入


5.对于任何一个 Model,为了方便处理数据流以及分布式需求,MMDetection 引入了两个 Model 的上层封装:单机版本 MMDataParallel、分布式(单机多卡或多机多卡)版本 MMDistributedDataParallel


6.Model 运行后会输出 loss 以及其他一些信息,会通过 logger 进行保存或者可视化


7.为了更好地解耦, 方便地获取各个组件之间依赖和灵活扩展,MMDetection 引入了 Runner 类进行全生命周期管理,并且通过 Hook 方便的获取、修改和拦截任何生命周期数据流,扩展非常便捷


按照数据流过程,测试流程可以简单总结为:


测试流程就比较简单,直接对 DataLoader 输出的数据进行前向推理即可,还原到最终原图尺度过程也是在 Model 中完成。


以上就是 MMDetection 框架整体训练和测试抽象流程,上图不仅仅反映了训练和测试数据流,而且还包括了模块和模块之间的调用关系。对于训练而言,最核心部分应该是 Runner,理解了 Runner 的运行流程,也就理解了整个 MMDetection 数据流。


2 第二层模块抽象


在总体把握了整个 MMDetection 框架训练和测试流程后,下个层次是每个模块内部抽象流程,主要包括 Pipeline、DataParallel、Model、Runner 和 Hooks。


2.1 Pipeline


Pipeline 实际上由一系列按照插入顺序运行的数据处理模块组成,每个模块完成某个特定功能,例如 Resize,因为其流式顺序运行特性,故叫做 Pipeline。


如下图所示,即非常典型的训练流程 Pipeline,每个类都接收字典输入,输出也是字典,顺序执行,其中绿色表示该类运行后新增字段,橙色表示对该字段可能会进行修改。


badfe894000c489ea33da5d626960e39.png


如果进一步细分的话,不同算法的 Pipeline 都可以划分为如下部分:


  • 图片和标签加载,通常用的类是 LoadImageFromFile 和 LoadAnnotations


  • 数据前处理,例如统一 Resize


  • 数据增强,典型的例如各种图片几何变换等,这部分是训练流程特有,测试阶段一般不采用(多尺度测试采用其他实现方式)【最常修改这里】


  • 数据收集,例如 Collect


在 MMDetection 框架中,图片和标签加载和数据后处理流程一般是固定的,用户主要可能修改的是数据增强步骤,目前已经接入了第三方增强库 Albumentations,可以按照示例代码轻松构建属于你自己的数据增强 Pipeline。


在构建自己的 Pipeline 时候一定要仔细检查修改或者新增的字典 key 和 value,因为一旦错误地覆盖或者修改原先字典里面的内容,代码也可能不会报错,如果出现 bug,则比较难排查。

2.2 DataParallel 和 Model


在 MMDetection 中 DataLoader 输出的内容不是 pytorch 能处理的标准格式,还包括了 DataContainer 对象,该对象的作用是包装不同类型的对象使之能按需组成 batch。


在目标检测中,每张图片 gt bbox 个数是不一样的,如果想组成 batch tensor,要么你设置最大长度,要么你自己想办法组成 batch。而考虑到内存和效率,MMDetection 通过引入 DataContainer 模块来解决上述问题,但是随之带来的问题是 pytorch 无法解析 DataContainer 对象,故需要在 MMDetection 中自行处理。


解决办法其实非常多,MMDetection 选择了一种比较优雅的实现方式:MMDataParallel 和 MMDistributedDataParallel。具体来说,这两个类相比 PyTorch 自带的 DataParallel 和 DistributedDataParallel 区别是:


  • 可以处理 DataContainer 对象


  • 额外实现了 train_step() 和 val_step() 两个函数,可以被 Runner 调用


关于这两个类的具体实现后面会描述。


而 Model 部分内容就是第一篇解读文章所讲的,具体如下:


dae5ca647e524fce83d6a9b505f51d77.png


2.3 Runner 和 Hooks


对于任何一个目标检测算法,都需要包括优化器、学习率设置、权重保存等等组件才能构成完整训练流程,而这些组件是通用的。


为了方便 OpenMMLab 体系下的所有框架复用,在 MMCV 框架中引入了 Runner 类来统一管理训练和验证流程,并且通过 Hooks 机制以一种非常灵活、解耦的方式来实现丰富扩展功能。


关于 Runner 和 Hooks 详细解读会发布在 MMCV 系列解读文章中,简单来说 Runner 封装了 OpenMMLab 体系下各个框架的训练和验证详细流程,其负责管理训练和验证过程中的整个生命周期,通过预定义回调函数,用户可以插入定制化 Hook ,从而实现各种各样的需求。下面列出了在 MMDetection 几个非常重要的 hook 以及其作用的生命周期:


bf62060c657c47b8a5b2f4abe89fff46.png


例如 CheckpointHook 在每个训练 epoch 完成后会被调用,从而实现保存权重功能。用户也可以将自己定制实现的 Hook 采用上述方式绘制,对理解整个流程或许有帮助。


3 第三层代码抽象


前面两层抽象分析流程,基本上把整个 MMDetection 的训练和测试流程分析完了,下面从具体代码层面进行抽象分析。


3.1 训练和测试整体代码抽象流程


487d0afb07c9461889b9ebd8cddbefd1.png


上图为训练和验证的和具体代码相关的整体抽象流程,对应到代码上,其核心代码如下:


#=================== tools/train.py ==================
# 1.初始化配置
cfg = Config.fromfile(args.config)
# 2.判断是否为分布式训练模式
# 3.初始化 logger
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# 4.收集运行环境并且打印,方便排查硬件和软件相关问题
env_info_dict = collect_env()
# 5.初始化 model
model = build_detector(cfg.model, ...)
# 6.初始化 datasets
#=================== mmdet/apis/train.py ==================
# 1.初始化 data_loaders ,内部会初始化 GroupSampler
data_loader = DataLoader(dataset,...)
# 2.基于是否使用分布式训练,初始化对应的 DataParallel
if distributed:
  model = MMDistributedDataParallel(...)
else:
  model = MMDataParallel(...)
# 3.初始化 runner
runner = EpochBasedRunner(...)
# 4.注册必备 hook
runner.register_training_hooks(cfg.lr_config, optimizer_config,
                               cfg.checkpoint_config, cfg.log_config,
                               cfg.get('momentum_config', None))
# 5.如果需要 val,则还需要注册 EvalHook           
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
# 6.注册用户自定义 hook
runner.register_hook(hook, priority=priority)
# 7.权重恢复和加载
if cfg.resume_from:
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)
# 8.运行,开始训练
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


上面的流程比较简单,一般大家比较难以理解的是 runner.run 内部逻辑,下小节进行详细分析,而对于测试逻辑由于比较简单,就不详细描述了,简单来说测试流程下不需要 runner,直接加载训练好的权重,然后进行 model 推理即可。


3.2 Runner 训练和验证代码抽象


runner 对象内部的 run 方式是一个通用方法,可以运行任何 workflow,目前常用的主要是 train 和 val。


  • 当配置为:workflow = [(‘train’, 1)],表示仅仅进行 train workflow,也就是迭代训练


  • 当配置为:workflow = [(‘train’, n),(‘val’, 1)],表示先进行 n 个 epoch 的训练,然后再进行1个 epoch 的验证,然后循环往复,如果写成 [(‘val’, 1),(‘train’, n)] 表示先进行验证,然后才开始训练


当进入对应的 workflow,则会调用 runner 里面的 train() 或者 val(),表示进行一次 epoch 迭代。其代码也非常简单,如下所示:


def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(self.data_loader):
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True)
        self.call_hook('after_train_iter')
    self.call_hook('after_train_epoch')
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    for i, data_batch in enumerate(self.data_loader):
        self.call_hook('before_val_iter')
        with torch.no_grad():
            self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')
    self.call_hook('after_val_epoch')


核心函数实际上是 self.run_iter(),如下:


def run_iter(self, data_batch, train_mode, **kwargs):
    if train_mode:
        # 对于每次迭代,最终是调用如下函数
        outputs = self.model.train_step(data_batch,...)
    else:
        # 对于每次迭代,最终是调用如下函数
        outputs = self.model.val_step(data_batch,...)
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'],...)
    self.outputs = outputs


上述 self.call_hook() 表示在不同生命周期调用所有已经注册进去的 hook,而字符串参数表示对应的生命周期。


以 OptimizerHook 为例,其执行反向传播、梯度裁剪和参数更新等核心训练功能:


@HOOKS.register_module()
class OptimizerHook(Hook):
    def __init__(self, grad_clip=None):
        self.grad_clip = grad_clip
    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
        runner.optimizer.step()


可以发现 OptimizerHook 注册到的生命周期是 after_train_iter,故在每次 train() 里面运行到self.call_hook('aftertrainiter') 时候就会被调用,其他 hook 也是同样运行逻辑。


3.3 Model 训练和测试代码抽象


训练和验证的时候实际上调用了 model 内部的 train_step 和 val_step 函数,因此理解这两个函数调用流程也就理解了 MMDetection 训练和测试流程。


注意,由于 model 对象会被 DataParallel 类包裹,故实际上此时的 model,是指的 MMDataParallel 或者 MMDistributedDataParallel。


以非分布式 train_step 流程为例,其内部完成调用流程图示如下:


4964fb4f688348a2885bf5672ed2728b.png


3.3.1 train 或者 val 流程


(1) 调用 runner 中的 train_step 或者 val_step


在 runner 中调用 train_step 或者 val_step,代码如下:


#=================== mmcv/runner/epoch_based_runner.py ==================
if train_mode:
    outputs = self.model.train_step(data_batch,...)
else:
    outputs = self.model.val_step(data_batch,...)
# 实际上,首先会调用 DataParallel 中的 train_step 或者 val_step ,其具体调用流程为:
# 非分布式训练
#=================== mmcv/parallel/data_parallel.py/MMDataParallel ==================
def train_step(self, *inputs, **kwargs):
    if not self.device_ids:
        # scatter():处理 DataContainer 格式数据,使其能够组成 batch,否则程序会报错。
        inputs, kwargs = self.scatter(inputs, kwargs, [-1])
        # 此时才是调用 model 本身的 train_step
        return self.module.train_step(*inputs, **kwargs)
    # 单 gpu 模式
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # 此时才是调用 model 本身的 train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
# val_step 也是的一样逻辑
def val_step(self, *inputs, **kwargs):
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # 此时才是调用 model 本身的 val_step
    return self.module.val_step(*inputs[0], **kwargs[0])


可以发现,在调用 model 本身的 train_step 前,需要额外调用 scatter 函数,该函数的作用是处理 DataContainer 格式数据,使其能够组成 batch,否则程序会报错。


如果是分布式训练,则调用的实际上是 mmcv/parallel/distributed.py/MMDistributedDataParallel,最终调用的依然是 model 本身的 train_step 或者 val_step。


(2) 调用 model 中的 train_step 或者 val_step


其核心代码如下:


#=================== mmdet/models/detectors/base.py/BaseDetector ==================
def train_step(self, data, optimizer):
    # 调用本类自身的 forward 方法
    losses = self(**data)
    # 解析 loss
    loss, log_vars = self._parse_losses(losses)
    # 返回字典对象
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
    return outputs
def forward(self, img, img_metas, return_loss=True, **kwargs):
    if return_loss:
        # 训练模式
        return self.forward_train(img, img_metas, **kwargs)
    else:
        # 测试模式
        return self.forward_test(img, img_metas, **kwargs)


forward_train 和 forward_test 需要在不同的算法子类中实现,输出是 Loss 或者 预测结果。


(3) 调用子类中的 forward_train 方法


目前提供了两个具体子类,TwoStageDetector 和 SingleStageDetector ,用于实现 two-stage 和 single-stage 算法。


对于 TwoStageDetector 而言,其核心逻辑是:


#============= mmdet/models/detectors/two_stage.py/TwoStageDetector ============
def forward_train(...):
    # 先进行 backbone+neck 的特征提取
    x = self.extract_feat(img)
    losses = dict()
    # RPN forward and loss
    if self.with_rpn:
        # 训练 RPN
        proposal_cfg = self.train_cfg.get('rpn_proposal',
                                          self.test_cfg.rpn)
        # 主要是调用 rpn_head 内部的 forward_train 方法
        rpn_losses, proposal_list = self.rpn_head.forward_train(x,...)
        losses.update(rpn_losses)
    else:
        proposal_list = proposals
    # 第二阶段,主要是调用 roi_head 内部的 forward_train 方法
    roi_losses = self.roi_head.forward_train(x, ...)
    losses.update(roi_losses)
    return losses


对于 SingleStageDetector 而言,其核心逻辑是:


#============= mmdet/models/detectors/single_stage.py/SingleStageDetector ============
def forward_train(...):
    super(SingleStageDetector, self).forward_train(img, img_metas)
    # 先进行 backbone+neck 的特征提取
    x = self.extract_feat(img)
    # 主要是调用 bbox_head 内部的 forward_train 方法
    losses = self.bbox_head.forward_train(x, ...)
    return losses


如果再往里分析,那就到各个 Head 模块的训练环节了,这部分内容请读者自行分析,应该不难。


3.3.2 test 流程


由于没有 runner 对象,测试流程简单很多,下面简要概述:


1.调用 MMDataParallel 或 MMDistributedDataParallel 中的 forward 方法


2.调用 base.py 中的 forward 方法


3.调用 base.py 中的 self.forward_test 方法


4.如果是单尺度测试,则会调用 TwoStageDetector 或 SingleStageDetector 中的 simple_test 方法,如果是多尺度测试,则调用 aug_test 方法


5.最终调用的是每个具体算法 Head 模块的 simple_test 或者 aug_test 方法


4 总结


本文基于第一篇解读文章,详细地从三个层面全面解读了 MMDetection 框架,希望读者读完本文,能够对 MMDetection 框架设计思想、组件间关系和整体代码实现流程了然于心。

目录
相关文章
|
27天前
|
监控 jenkins 测试技术
自动化测试框架的构建与实践
【10月更文挑战第40天】在软件开发周期中,测试环节扮演着至关重要的角色。本文将引导你了解如何构建一个高效的自动化测试框架,并深入探讨其设计原则、实现方法及维护策略。通过实际代码示例和清晰的步骤说明,我们将一起探索如何确保软件质量,同时提升开发效率。
37 1
|
1月前
|
测试技术 开发者 Python
自动化测试之美:从零构建你的软件质量防线
【10月更文挑战第34天】在数字化时代的浪潮中,软件成为我们生活和工作不可或缺的一部分。然而,随着软件复杂性的增加,如何保证其质量和稳定性成为开发者面临的一大挑战。自动化测试,作为现代软件开发过程中的关键实践,不仅提高了测试效率,还确保了软件产品的质量。本文将深入浅出地介绍自动化测试的概念、重要性以及实施步骤,带领读者从零基础开始,一步步构建起属于自己的软件质量防线。通过具体实例,我们将探索如何有效地设计和执行自动化测试脚本,最终实现软件开发流程的优化和产品质量的提升。无论你是软件开发新手,还是希望提高项目质量的资深开发者,这篇文章都将为你提供宝贵的指导和启示。
|
19天前
|
自然语言处理 算法 Python
再谈递归下降解析器:构建一个简单的算术表达式解析器
本文介绍了递归下降解析器的原理与实现,重点讲解了如何使用Python构建一个简单的算术表达式解析器。通过定义文法、实现词法分析器和解析器类,最终实现了对基本算术表达式的解析与计算功能。
91 52
|
16天前
|
弹性计算 持续交付 API
构建高效后端服务:微服务架构的深度解析与实践
在当今快速发展的软件行业中,构建高效、可扩展且易于维护的后端服务是每个技术团队的追求。本文将深入探讨微服务架构的核心概念、设计原则及其在实际项目中的应用,通过具体案例分析,展示如何利用微服务架构解决传统单体应用面临的挑战,提升系统的灵活性和响应速度。我们将从微服务的拆分策略、通信机制、服务发现、配置管理、以及持续集成/持续部署(CI/CD)等方面进行全面剖析,旨在为读者提供一套实用的微服务实施指南。
|
13天前
|
监控 数据管理 测试技术
API接口自动化测试深度解析与最佳实践指南
本文详细介绍了API接口自动化测试的重要性、核心概念及实施步骤,强调了从明确测试目标、选择合适工具、编写高质量测试用例到构建稳定测试环境、执行自动化测试、分析测试结果、回归测试及集成CI/CD流程的全过程,旨在为开发者提供一套全面的技术指南,确保API的高质量与稳定性。
|
21天前
|
监控 持续交付 数据库
构建高效的后端服务:微服务架构的深度解析
在现代软件开发中,微服务架构已成为提升系统可扩展性、灵活性和维护性的关键。本文深入探讨了微服务架构的核心概念、设计原则和最佳实践,通过案例分析展示了如何在实际项目中有效地实施微服务策略,以及面临的挑战和解决方案。文章旨在为开发者提供一套完整的指导框架,帮助他们构建出更加高效、稳定的后端服务。
|
24天前
|
jenkins 测试技术 持续交付
自动化测试框架的构建与优化:提升软件交付效率的关键####
本文深入探讨了自动化测试框架的核心价值,通过对比传统手工测试方法的局限性,揭示了自动化测试在现代软件开发生命周期中的重要性。不同于常规摘要仅概述内容,本部分强调了自动化测试如何显著提高测试覆盖率、缩短测试周期、降低人力成本,并促进持续集成/持续部署(CI/CD)流程的实施,最终实现软件质量和开发效率的双重飞跃。通过具体案例分析,展示了从零开始构建自动化测试框架的策略与最佳实践,包括选择合适的工具、设计高效的测试用例结构、以及如何进行性能调优等关键步骤。此外,还讨论了在实施过程中可能遇到的挑战及应对策略,为读者提供了一套可操作的优化指南。 ####
|
25天前
|
敏捷开发 监控 测试技术
探索自动化测试框架的构建与优化####
在软件开发周期中,自动化测试扮演着至关重要的角色。本文旨在深入探讨如何构建高效的自动化测试框架,并分享一系列实用策略以提升测试效率和质量。我们将从框架选型、结构设计、工具集成、持续集成/持续部署(CI/CD)、以及最佳实践等多个维度进行阐述,为软件测试人员提供一套系统化的实施指南。 ####
|
10天前
|
监控 搜索推荐 测试技术
电商API的测试与用途:深度解析与实践
在电子商务蓬勃发展的今天,电商API成为连接电商平台、商家、消费者和第三方开发者的重要桥梁。本文深入探讨了电商API的核心功能,包括订单管理、商品管理、用户管理、支付管理和物流管理,并介绍了有效的测试技巧,如理解API文档、设计测试用例、搭建测试环境、自动化测试、压力测试、安全性测试等。文章还详细阐述了电商API的多样化用途,如商品信息获取、订单管理自动化、用户数据管理、库存同步、物流跟踪、支付处理、促销活动管理、评价管理、数据报告和分析、扩展平台功能及跨境电商等,旨在为开发者和电商平台提供有益的参考。
18 0
|
1月前
|
监控 安全 测试技术
构建高效的精准测试平台:设计与实现指南
在软件开发过程中,精准测试是确保产品质量和性能的关键环节。一个精准的测试平台能够自动化测试流程,提高测试效率,缩短测试周期,并提供准确的测试结果。本文将分享如何设计和实现一个精准测试平台,从需求分析到技术选型,再到具体的实现步骤。
119 1

推荐镜像

更多