YOLOv5改进 | 主干网络 | 将backbone替换为MobileNetV3【小白必备教程+附完整代码】

本文涉及的产品
视觉智能开放平台,图像通用资源包5000点
视觉智能开放平台,分割抠图1万点
视觉智能开放平台,视频通用资源包5000点
简介: 本文介绍了将YOLOv5的backbone替换为MobileNetV3以提升目标检测性能的教程。MobileNetV3采用倒残差结构、Squeeze-and-Excitation模块和Hard-Swish激活函数,实现更高性能和更低计算成本。文中提供了详细的代码实现,包括MobileNetV3的关键组件和YOLOv5的配置修改,便于读者实践。此外,还分享了完整代码链接和进一步的进阶策略,适合深度学习初学者和进阶者学习YOLO系列。
💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡

尽管Ultralytics 推出了最新版本的 YOLOv8 模型。但YOLOv5作为一个anchor base的目标检测的算法,YOLOv5可能比YOLOv8的效果更好。注意力机制是提高模型性能最热门的方法之一,本文给大家带来的教程是将YOLOv5的backbone替换为MobileNetV3结构来提取特征。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址:YOLOv5改进+入门——持续更新各种有效涨点方法

1.原理

MobileNetV3 是 Google 提出的一种轻量级神经网络结构,旨在在移动设备上实现高效的图像识别和分类任务。与之前的 MobileNet 系列相比,MobileNetV3 在模型结构和性能上都有所改进。

MobileNetV3 的结构主要包括以下几个关键组件:

  1. 基础模块(Base Module):MobileNetV3 使用了一种称为“倒残差”(Inverted Residuals)的基础模块结构。该结构采用了深度可分离卷积和线性瓶颈,以减少参数数量和计算复杂度,并且保持了模型的有效性。

  2. Squeeze-and-Excitation 模块:MobileNetV3 引入了 Squeeze-and-Excitation 模块,通过学习通道之间的相互关系,动态地调整通道权重,以增强模型的表征能力。这有助于提高模型对关键特征的感知能力,从而提高分类性能。

  3. Hard-Swish 激活函数:MobileNetV3 使用了一种称为 Hard-Swish 的激活函数。与传统的 ReLU 激活函数相比,Hard-Swish 具有更快的计算速度和更好的性能。

  4. 网络架构优化:MobileNetV3 在网络结构上进行了优化,包括通过网络宽度和分辨率的动态调整,以适应不同的计算资源和任务需求。

总体而言,MobileNetV3 通过这些创新设计和优化,实现了更高的性能和更低的计算成本,使其成为移动设备上图像识别任务的理想选择之一。

精简原理参考:百面算法工程师 | 分类网络总结-CSDN博客

2.MobileNetv3代码实现

2.1 将MobileNetv3代码添加到YOLOv5中

image.png
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        y = self.sigmoid(x)
        return x * y


class SELayer(nn.Module):
    def __init__(self, channel, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // reduction),
                nn.ReLU(inplace=True),
                nn.Linear(channel // reduction, channel),
                h_sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x)
        y = y.view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class Conv3BN(nn.Module):

    def __init__(self, inp, oup, stride):
        super(Conv3BN, self).__init__()
        self.conv = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
        self.bn = nn.BatchNorm2d(oup)
        self.act = h_swish()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))

class MobileNetv3(nn.Module):
    def __init__(self, inp, oup, hidden_dim, kernel_size, stride, use_se, use_hs):
        super(MobileNetv3, self).__init__()
        assert stride in [1, 2]

        self.identity = stride == 1 and inp == oup

        if inp == hidden_dim:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                h_swish() if use_hs else nn.ReLU(inplace=True),
                # Squeeze-and-Excite
                SELayer(hidden_dim) if use_se else nn.Sequential(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                h_swish() if use_hs else nn.ReLU(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                # Squeeze-and-Excite
                SELayer(hidden_dim) if use_se else nn.Sequential(),
                h_swish() if use_hs else nn.ReLU(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        y = self.conv(x)
        if self.identity:
            return x + y
        else:
            return y

MobileNetV3 的主要处理流程如下:

  1. 输入处理:输入图像首先经过预处理步骤,例如归一化和大小调整,以使其适应网络的输入要求。

  2. 特征提取:经过输入处理后,图像通过一系列基础模块(Base Module)进行特征提取。每个基础模块通常包含深度可分离卷积、激活函数(如 Hard-Swish)和通道注意力(如 Squeeze-and-Excitation)模块。

  3. 特征增强:在特征提取的过程中,通过 Squeeze-and-Excitation 模块对提取的特征图进行增强,以加强对重要特征的感知能力。

  4. 全局平均池化:在特征提取的最后阶段,通过全局平均池化操作将特征图的空间维度降低到一个固定大小。

  5. 分类器:全局平均池化后的特征图输入到分类器中,进行分类或其他任务的预测。分类器通常由一个或多个全连接层组成,最后输出预测结果。

整个流程通过堆叠和连接这些组件来完成。MobileNetV3 的设计旨在保持模型的轻量化和高效性能,以适应移动设备上的图像识别和分类任务。

image.png

2.2新增yaml文件

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# parameters
nc: 2  # number of classes
depth_multiple: 1.0  # dont change this otherwise InvertedResidual will be affected
width_multiple: 1.0  # dont change this otherwise InvertedResidual will be affected

# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 backbone
backbone:
  # MobileNetV3-large
  # [from, number, module, args]
  [[-1, 1, Conv3BN, [16, 2]],                          # 0-p1/2
   [-1, 1, MobileNetv3, [ 16,  16, 3, 1, 0, 0]],  # 1-p1/2
   [-1, 1, MobileNetv3, [ 24,  64, 3, 2, 0, 0]],  # 2-p2/4
   [-1, 1, MobileNetv3, [ 24,  72, 3, 1, 0, 0]],  # 3-p2/4
   [-1, 1, MobileNetv3, [ 40,  72, 5, 2, 1, 0]],  # 4-p3/8
   [-1, 1, MobileNetv3, [ 40, 120, 5, 1, 1, 0]],  # 5-p3/8
   [-1, 1, MobileNetv3, [ 40, 120, 5, 1, 1, 0]],  # 6-p3/8
   [-1, 1, MobileNetv3, [ 80, 240, 3, 2, 0, 1]],  # 7-p4/16
   [-1, 1, MobileNetv3, [ 80, 200, 3, 1, 0, 1]],  # 8-p4/16
   [-1, 1, MobileNetv3, [ 80, 184, 3, 1, 0, 1]],  # 9-p4/16
   [-1, 1, MobileNetv3, [ 80, 184, 3, 1, 0, 1]],  # 10-p4/16
   [-1, 1, MobileNetv3, [112, 480, 3, 1, 1, 1]],  # 11-p4/16
   [-1, 1, MobileNetv3, [112, 672, 3, 1, 1, 1]],  # 12-p4/16
   [-1, 1, MobileNetv3, [160, 672, 5, 1, 1, 1]],  # 13-p4/16
   [-1, 1, MobileNetv3, [160, 672, 5, 2, 1, 1]],  # 14-p5/32
   [-1, 1, MobileNetv3, [160, 960, 5, 1, 1, 1]],  # 15-p5/32
  ]

# YOLOv5 head
head:
  [[-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 13], 1, Concat, [1]],  # cat backbone P4
   [-1, 1, C3, [256, False]],  # 19

   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P3
   [-1, 1, C3, [128, False]],  # 23 (P3/8-small)

   [-1, 1, Conv, [128, 3, 2]],
   [[-1, 20], 1, Concat, [1]],  # cat head P4
   [-1, 1, C3, [256, False]],  # 26 (P4/16-medium)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 16], 1, Concat, [1]],  # cat head P5
   [-1, 1, C3, [512, False]],  # 29 (P5/32-large)

   [[23, 26, 29], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

温馨提示:因为本文只是对yolov5n基础上添加swin模块,如果要对yolov5n/l/m/x进行添加则只需要修改对应的depth_multiple 和 width_multiple。

yolov5n/l/m/x对应的depth_multiple 和 width_multiple如下:

# YOLOv5n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple

# YOLOv5s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

# YOLOv5l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# YOLOv5m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple

# YOLOv5x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple

2.3 注册模块

image.png

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_mobilemetv3.yaml的路径,如下图所示

建议大家写绝对路径,确保一定能找到

🚀运行程序,如果出现下面的内容则说明添加成功🚀

image.png

3.完整代码分享

YOLOv5改进 | 主干网络 | 将backbone替换为MobileNetV3【小白必备教程+附完整代码】

👆是我修改后的完整,提取码unah

4.总结

MobileNetV3 是一种轻量级神经网络结构,通过倒残差基础模块、Squeeze-and-Excitation 模块和Hard-Swish 激活函数等创新技术,实现了在移动设备等资源受限环境下高效的图像识别和分类任务。其处理流程包括输入处理、特征提取、特征增强、全局平均池化和分类器预测,通过优化网络结构和设计,MobileNetV3 在保持性能的同时减少了计算负载,适用于移动设备上的应用场景。

5. 进阶

在backbone中单独添加一层或者是替换一个C3模块,也是可行的策略,这里直接放上改进后的代码。提示:也可以尝试替换其他啊部位的C3模块。

YOLOv5改进 | 主干网络 | 将backbone替换为MobileNetV3【小白必备教程+附完整代码】
👆是我修改后的完整,提取码: 6wpu

可能在实验过程可能会遇到问题,如果你不能解决,欢迎评论区提问

相关文章
|
26天前
|
机器学习/深度学习 算法 测试技术
图神经网络在信息检索重排序中的应用:原理、架构与Python代码解析
本文探讨了基于图的重排序方法在信息检索领域的应用与前景。传统两阶段检索架构中,初始检索速度快但结果可能含噪声,重排序阶段通过强大语言模型提升精度,但仍面临复杂需求挑战
68 0
图神经网络在信息检索重排序中的应用:原理、架构与Python代码解析
|
1月前
|
Cloud Native 区块链 数据中心
Arista CloudEOS 4.32.2F - 云网络基础架构即代码
Arista CloudEOS 4.32.2F - 云网络基础架构即代码
41 1
|
2月前
|
数据采集 存储 监控
Python 原生爬虫教程:网络爬虫的基本概念和认知
网络爬虫是一种自动抓取互联网信息的程序,广泛应用于搜索引擎、数据采集、新闻聚合和价格监控等领域。其工作流程包括 URL 调度、HTTP 请求、页面下载、解析、数据存储及新 URL 发现。Python 因其丰富的库(如 requests、BeautifulSoup、Scrapy)和简洁语法成为爬虫开发的首选语言。然而,在使用爬虫时需注意法律与道德问题,例如遵守 robots.txt 规则、控制请求频率以及合法使用数据,以确保爬虫技术健康有序发展。
269 31
|
2月前
|
域名解析 API PHP
VM虚拟机全版本网盘+免费本地网络穿透端口映射实时同步动态家庭IP教程
本文介绍了如何通过网络穿透技术让公网直接访问家庭电脑,充分发挥本地硬件性能。相比第三方服务受限于转发带宽,此方法利用自家宽带实现更高效率。文章详细讲解了端口映射教程,包括不同网络环境(仅光猫、光猫+路由器)下的设置步骤,并提供实时同步动态IP的两种方案:自建服务器或使用三方API接口。最后附上VM虚拟机全版本下载链接,便于用户在穿透后将服务运行于虚拟环境中,提升安全性与适用性。
|
4月前
|
机器学习/深度学习 自然语言处理 计算机视觉
RT-DETR改进策略【Backbone/主干网络】| CVPR 2024 替换骨干网络为 RMT,增强空间信息的感知能力
RT-DETR改进策略【Backbone/主干网络】| CVPR 2024 替换骨干网络为 RMT,增强空间信息的感知能力
195 13
RT-DETR改进策略【Backbone/主干网络】| CVPR 2024 替换骨干网络为 RMT,增强空间信息的感知能力
|
4月前
|
机器学习/深度学习 计算机视觉 网络架构
RT-DETR改进策略【Backbone/主干网络】| CVPR 2024替换骨干网络为 UniRepLKNet,解决大核 ConvNets 难题
RT-DETR改进策略【Backbone/主干网络】| CVPR 2024替换骨干网络为 UniRepLKNet,解决大核 ConvNets 难题
343 12
RT-DETR改进策略【Backbone/主干网络】| CVPR 2024替换骨干网络为 UniRepLKNet,解决大核 ConvNets 难题
|
4月前
|
监控 Linux PHP
【02】客户端服务端C语言-go语言-web端PHP语言整合内容发布-优雅草网络设备监控系统-2月12日优雅草简化Centos stream8安装zabbix7教程-本搭建教程非docker搭建教程-优雅草solution
【02】客户端服务端C语言-go语言-web端PHP语言整合内容发布-优雅草网络设备监控系统-2月12日优雅草简化Centos stream8安装zabbix7教程-本搭建教程非docker搭建教程-优雅草solution
132 20
|
6月前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
164 17
|
6月前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
124 10
|
6月前
|
存储 SQL 安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。