MMsegmentation教程 4: 自定义模型

简介: MMsegmentation教程 4: 自定义模型

自定义优化器 (optimizer)


假设您想增加一个新的叫 MyOptimizer 的优化器,它的参数分别为 a, b, 和 c

您首先需要在一个文件里实现这个新的优化器,例如在 mmseg/core/optimizer/my_optimizer.py 里面:

from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer
@OPTIMIZERS.register_module
class MyOptimizer(Optimizer):
    def __init__(self, a, b, c)


然后增加这个模块到 mmseg/core/optimizer/__init__.py 里面,这样注册器 (registry) 将会发现这个新的模块并添加它:

from .my_optimizer import MyOptimizer


之后您可以在配置文件的 optimizer 域里使用 MyOptimizer

如下所示,在配置文件里,优化器被 optimizer 域所定义:

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)


为了使用您自己的优化器,域可以被修改为:

optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)


我们已经支持了 PyTorch 自带的全部优化器,唯一修改的地方是在配置文件里的 optimizer 域。例如,如果您想使用 ADAM,尽管数值表现会掉点,还是可以如下修改:

optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)


使用者可以直接按照 PyTorch 文档教程 去设置参数。


定制优化器的构造器 (optimizer constructor)


对于优化,一些模型可能会有一些特别定义的参数,例如批归一化 (BatchNorm) 层里面的权重衰减 (weight decay)。

使用者可以通过定制优化器的构造器来微调这些细粒度的优化器参数。

from mmcv.utils import build_from_cfg
from mmcv.runner import OPTIMIZER_BUILDERS
from .cocktail_optimizer import CocktailOptimizer
@OPTIMIZER_BUILDERS.register_module
class CocktailOptimizerConstructor(object):
    def __init__(self, optimizer_cfg, paramwise_cfg=None):
    def __call__(self, model):
        return my_optimizer


开发和增加新的组件(Module)


MMSegmentation 里主要有2种组件:

  • 主干网络 (backbone): 通常是卷积网络的堆叠,来做特征提取,例如 ResNet, HRNet
  • 解码头 (decoder head): 用于语义分割图的解码的组件(得到分割结果)


添加新的主干网络


这里我们以 MobileNet 为例,展示如何增加新的主干组件:

  1. 创建一个新的文件 mmseg/models/backbones/mobilenet.py

import torch.nn as nn
from ..registry import BACKBONES
@BACKBONES.register_module
class MobileNet(nn.Module):
    def __init__(self, arg1, arg2):
        pass
    def forward(self, x):  # should return a tuple
        pass
    def init_weights(self, pretrained=None):
        pass


  1. mmseg/models/backbones/__init__.py 里面导入模块

from .mobilenet import MobileNet


  1. 在您的配置文件里使用它

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...


增加新的解码头 (decoder head)组件


在 MMSegmentation 里面,对于所有的分割头,我们提供一个基类解码头 BaseDecodeHead

所有新建的解码头都应该继承它。这里我们以 PSPNet 为例,

展示如何开发和增加一个新的解码头组件:

首先,在 mmseg/models/decode_heads/psp_head.py 里添加一个新的解码头。

PSPNet 中实现了一个语义分割的解码头。为了实现一个解码头,我们只需要在新构造的解码头中实现如下的3个函数:

@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
    def init_weights(self):
    def forward(self, inputs):


接着,使用者需要在 mmseg/models/decode_heads/__init__.py 里面添加这个模块,这样对应的注册器 (registry) 可以查找并加载它们。

PSPNet的配置文件如下所示:

norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))


增加新的损失函数


假设您想添加一个新的损失函数 MyLoss 到语义分割解码器里。

为了添加一个新的损失函数,使用者需要在 mmseg/models/losses/my_loss.py 里面去实现它。

weighted_loss 可以对计算损失时的每个样本做加权。

import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss
@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss
@LOSSES.register_module
class MyLoss(nn.Module):
    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss


然后使用者需要在 mmseg/models/losses/__init__.py 里面添加它:

from .my_loss import MyLoss, my_loss


为了使用它,修改 loss_xxx 域。之后您需要在解码头组件里修改 loss_decode 域。

loss_weight 可以被用来对不同的损失函数做加权。

loss_decode=dict(type='MyLoss', loss_weight=1.0))


相关文章
|
8月前
|
机器学习/深度学习 Python
Scikit-Learn 高级教程——自定义评估器
Scikit-Learn 高级教程——自定义评估器【1月更文挑战第17篇】
132 1
|
2月前
|
Java
开发指南072-模型定义
平台当中有些对象是自定义表结构,时髦的说法就是模型
|
3月前
|
机器学习/深度学习 并行计算 数据可视化
目标分类笔记(二): 利用PaddleClas的框架来完成多标签分类任务(从数据准备到训练测试部署的完整流程)
这篇文章介绍了如何使用PaddleClas框架完成多标签分类任务,包括数据准备、环境搭建、模型训练、预测、评估等完整流程。
195 0
|
7月前
|
机器学习/深度学习 人工智能 Java
人工智能平台PAI产品使用合集之已经通过自定义镜像部署了一个模型,想要上传并导入其他模型,该如何操作
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
数据可视化 PyTorch 算法框架/工具
量化自定义PyTorch模型入门教程
在以前Pytorch只有一种量化的方法,叫做“eager mode qunatization”,在量化我们自定定义模型时经常会产生奇怪的错误,并且很难解决。但是最近,PyTorch发布了一种称为“fx-graph-mode-qunatization”的方方法。在本文中我们将研究这个fx-graph-mode-qunatization”看看它能不能让我们的量化操作更容易,更稳定。
256 0
|
8月前
|
机器学习/深度学习 Python
Scikit-Learn 高级教程——高级模型
Scikit-Learn 高级教程——高级模型【1月更文挑战第19篇】
140 5
|
8月前
|
机器学习/深度学习 JSON 自然语言处理
python自动化标注工具+自定义目标P图替换+深度学习大模型(代码+教程+告别手动标注)
python自动化标注工具+自定义目标P图替换+深度学习大模型(代码+教程+告别手动标注)
|
Kubernetes 安全 Docker
k8s教程(基础篇)-安装与配置概述
k8s教程(基础篇)-安装与配置概述
401 0
|
数据处理
InVEST模型的下载及入门操作(以InVEST3.13.0为例)
InVEST是一套免费的开源软件模型,是美国自然资本项目组开发的、用于评估生态系统服务功能量及其经济价值、支持生态系统管理和决策的一套模型系统,用于绘制和评估维持和实现人类生活的自然商品和服务。包括商品生产(如食物)、生命维持过程(如水净化)和充实生命的条件(如美丽、娱乐机会)以及选择的保护(如未来使用的遗传多样性)等模块。(翻译自模型官网)
2061 1
|
存储 数据可视化 Ubuntu
bcftools学习笔记丨软件简介、安装方式、使用方法、核心功能、参数解释等一文速览
bcftools学习笔记丨软件简介、安装方式、使用方法、核心功能、参数解释等一文速览