【轻量化:实操】动手实现神经网络中的裁枝操作(附演示代码&yolo系列)

简介: 【轻量化:实操】动手实现神经网络中的裁枝操作(附演示代码&yolo系列)

前言

   神经网络中的裁剪操作指的是去除网络中一些不必要的权重或节点,从而减少网络的复杂度和计算量,提高网络的效率和泛化能力。

裁枝简介

  神经网络模型中的参数数量很大,模型大小可能会限制模型的部署和使用,裁剪可以有效地减少网络中的参数数量,从而减小模型的大小,使其更加易于部署和使用;在神经网络在推理时需要进行大量的乘加操作,计算量非常大,裁剪可以减少网络中的冗余连接和参数,从而降低计算量,提高网络的推理速度;裁剪可以去除网络中的一些冗余连接和参数,从而减少网络的过拟合风险,提高网络的泛化能力;神经网络通常被认为是“黑盒子”,很难理解它们是如何做出决策的。裁剪可以去除网络中的一些冗余连接和参数,从而提高网络的可解释性,使人们更容易理解网络是如何做出决策的;裁剪可以降低神经网络的存储和计算要求,从而使神经网络更加适合在嵌入式设备等计算资源受限的环境中使用

总结:

  1. 减少模型大小
  2. 加速推理
  3. 提高泛化能力
  4. 更好的可解释性
  5. 更低的存储和计算要求

实现步骤

  神经网络裁剪(Pruning)是一种常见的神经网络压缩技术,可以通过去除网络中冗余的连接和参数来减少网络的大小和计算量,从而提高网络的效率和泛化能力。以下是在神经网络中进行裁枝操作的一般步骤:

  1. 训练原始网络:首先需要训练出一个较为精度的原始网络,该网络可以使用任何一种常见的训练算法进行训练,例如随机梯度下降(SGD)、Adam等。
  2. 计算权重重要性:使用一种重要性指标来衡量每个参数的重要性,例如绝对值大小、梯度大小等。一般情况下,重要性指标越小的参数越容易被裁剪。
  3. 去除冗余参数:将网络中重要性指标较小的参数或连接去除,可以通过将它们的权重值设为零或者将其对应的连接删去来实现。通常情况下,裁剪的比例在20%~90%之间。
  4. 重新训练网络:对裁剪后的网络进行微调,以便保持其精度和泛化能力。

实操步骤

  我们对较为经典的一个网络进行裁枝操作,在这里我们选择较为简单的网络作为模板进行操作。这里选择图像分类任务(步骤详见:# 【基础实操】借用torch自带网络进行训练自己的图像数据 ) 作为基础任务向大家展示:

ini

复制代码

import torch.nn as nn
import torch.nn.utils.prune as prune
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.relu4(x)
        x = self.fc3(x)
        return x
# 定义一个网络实例
net = Net()
# 对网络进行剪枝操作
parameters_to_prune = (
    (net.conv1, 'weight'),
    (net.conv2, 'weight'),
    (net.fc1, 'weight'),
    (net.fc2, 'weight'),
    (net.fc3, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
# 使用剪枝后的网络进行训练

对于上述网络中我们可以得到一个10分类的网络。

yolo 系列

  在bubbliiiing博主的yolo系列的代码中,我们在train.py文件中找到创建yolo模型,在载入model后,对bubbliiiing博主的主干网络中的三个特征图进行压缩裁枝操作。下面是部分代码,可供参考。

ini

复制代码

#------------------------------------------------------#
#   创建yolo模型
#------------------------------------------------------#
model = YoloBody(anchors_mask, num_classes, phi, pretrained=pretrained)
parameters_to_prune = (
    (model.yolo_head_P3, 'weight'),
    (model.yolo_head_P3, 'weight'),
    (model.yolo_head_P3, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)


相关文章
用MASM32按Time Protocol(RFC868)协议编写网络对时程序中的一些有用的函数代码
用MASM32按Time Protocol(RFC868)协议编写网络对时程序中的一些有用的函数代码
|
1月前
|
机器学习/深度学习 网络架构 计算机视觉
目标检测笔记(一):不同模型的网络架构介绍和代码
这篇文章介绍了ShuffleNetV2网络架构及其代码实现,包括模型结构、代码细节和不同版本的模型。ShuffleNetV2是一个高效的卷积神经网络,适用于深度学习中的目标检测任务。
68 1
目标检测笔记(一):不同模型的网络架构介绍和代码
|
22天前
|
机器学习/深度学习 计算机视觉 网络架构
【YOLO11改进 - C3k2融合】C3k2融合YOLO-MS的MSBlock : 分层特征融合策略,轻量化网络结构
【YOLO11改进 - C3k2融合】C3k2融合YOLO-MS的MSBlock : 分层特征融合策略,轻量化网络结构
|
2月前
|
安全 C#
某网络硬盘网站被植入传播Trojan.DL.Inject.xz等的代码
某网络硬盘网站被植入传播Trojan.DL.Inject.xz等的代码
|
3月前
|
安全 网络安全 开发者
探索Python中的装饰器:简化代码,增强功能网络安全与信息安全:从漏洞到防护
【8月更文挑战第30天】本文通过深入浅出的方式介绍了Python中装饰器的概念、用法和高级应用。我们将从基础的装饰器定义开始,逐步深入到如何利用装饰器来改进代码结构,最后探讨其在Web框架中的应用。适合有一定Python基础的开发者阅读,旨在帮助读者更好地理解并运用装饰器来优化他们的代码。
完成切换网络+修改网络连接图标提示的代码框架
完成切换网络+修改网络连接图标提示的代码框架
|
3月前
|
达摩院 供应链 JavaScript
网络流问题--仓储物流调度【数学规划的应用(含代码)】阿里达摩院MindOpt
本文通过使用MindOpt工具优化仓储物流调度问题,旨在提高物流效率并降低成本。首先,通过考虑供需匹配、运输时间与距离、车辆容量、仓库储存能力等因素构建案例场景。接着,利用数学规划方法,包括线性规划和网络流问题,来建立模型。在网络流问题中,通过定义节点(资源)和边(资源间的关系),确保流量守恒和容量限制条件下找到最优解。文中还详细介绍了MindOpt Studio云建模平台和MindOpt APL建模语言的应用,并通过实例展示了如何声明集合、参数、变量、目标函数及约束条件,并最终解析了求解结果。通过这些步骤,实现了在满足各仓库需求的同时最小化运输成本的目标。
|
3月前
|
开发者 图形学 API
从零起步,深度揭秘:运用Unity引擎及网络编程技术,一步步搭建属于你的实时多人在线对战游戏平台——详尽指南与实战代码解析,带你轻松掌握网络化游戏开发的核心要领与最佳实践路径
【8月更文挑战第31天】构建实时多人对战平台是技术与创意的结合。本文使用成熟的Unity游戏开发引擎,从零开始指导读者搭建简单的实时对战平台。内容涵盖网络架构设计、Unity网络API应用及客户端与服务器通信。首先,创建新项目并选择适合多人游戏的模板,使用推荐的网络传输层。接着,定义基本玩法,如2D多人射击游戏,创建角色预制件并添加Rigidbody2D组件。然后,引入网络身份组件以同步对象状态。通过示例代码展示玩家控制逻辑,包括移动和发射子弹功能。最后,设置服务器端逻辑,处理客户端连接和断开。本文帮助读者掌握构建Unity多人对战平台的核心知识,为进一步开发打下基础。
120 0
|
3月前
|
安全 开发者 数据安全/隐私保护
Xamarin 的安全性考虑与最佳实践:从数据加密到网络防护,全面解析构建安全移动应用的六大核心技术要点与实战代码示例
【8月更文挑战第31天】Xamarin 的安全性考虑与最佳实践对于构建安全可靠的跨平台移动应用至关重要。本文探讨了 Xamarin 开发中的关键安全因素,如数据加密、网络通信安全、权限管理等,并提供了 AES 加密算法的代码示例。
59 0
|
3月前
|
安全 Apache 数据安全/隐私保护
你的Wicket应用安全吗?揭秘在Apache Wicket中实现坚不可摧的安全认证策略
【8月更文挑战第31天】在当前的网络环境中,安全性是任何应用程序的关键考量。Apache Wicket 是一个强大的 Java Web 框架,提供了丰富的工具和组件,帮助开发者构建安全的 Web 应用程序。本文介绍了如何在 Wicket 中实现安全认证,
43 0

热门文章

最新文章