前言
神经网络中的裁剪操作指的是去除网络中一些不必要的权重或节点,从而减少网络的复杂度和计算量,提高网络的效率和泛化能力。
裁枝简介
神经网络模型中的参数数量很大,模型大小可能会限制模型的部署和使用,裁剪可以有效地减少网络中的参数数量,从而减小模型的大小,使其更加易于部署和使用;在神经网络在推理时需要进行大量的乘加操作,计算量非常大,裁剪可以减少网络中的冗余连接和参数,从而降低计算量,提高网络的推理速度;裁剪可以去除网络中的一些冗余连接和参数,从而减少网络的过拟合风险,提高网络的泛化能力;神经网络通常被认为是“黑盒子”,很难理解它们是如何做出决策的。裁剪可以去除网络中的一些冗余连接和参数,从而提高网络的可解释性,使人们更容易理解网络是如何做出决策的;裁剪可以降低神经网络的存储和计算要求,从而使神经网络更加适合在嵌入式设备等计算资源受限的环境中使用
总结:
- 减少模型大小
- 加速推理
- 提高泛化能力
- 更好的可解释性
- 更低的存储和计算要求
实现步骤
神经网络裁剪(Pruning)是一种常见的神经网络压缩技术,可以通过去除网络中冗余的连接和参数来减少网络的大小和计算量,从而提高网络的效率和泛化能力。以下是在神经网络中进行裁枝操作的一般步骤:
- 训练原始网络:首先需要训练出一个较为精度的原始网络,该网络可以使用任何一种常见的训练算法进行训练,例如随机梯度下降(SGD)、Adam等。
- 计算权重重要性:使用一种重要性指标来衡量每个参数的重要性,例如绝对值大小、梯度大小等。一般情况下,重要性指标越小的参数越容易被裁剪。
- 去除冗余参数:将网络中重要性指标较小的参数或连接去除,可以通过将它们的权重值设为零或者将其对应的连接删去来实现。通常情况下,裁剪的比例在20%~90%之间。
- 重新训练网络:对裁剪后的网络进行微调,以便保持其精度和泛化能力。
实操步骤
我们对较为经典的一个网络进行裁枝操作,在这里我们选择较为简单的网络作为模板进行操作。这里选择图像分类任务(步骤详见:# 【基础实操】借用torch自带网络进行训练自己的图像数据 ) 作为基础任务向大家展示:
ini
复制代码
import torch.nn as nn import torch.nn.utils.prune as prune class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, kernel_size=5) self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(6, 16, kernel_size=5) self.relu2 = nn.ReLU(inplace=True) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.relu3 = nn.ReLU(inplace=True) self.fc2 = nn.Linear(120, 84) self.relu4 = nn.ReLU(inplace=True) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.conv1(x) x = self.relu1(x) x = self.conv2(x) x = self.relu2(x) x = x.view(-1, 16 * 5 * 5) x = self.fc1(x) x = self.relu3(x) x = self.fc2(x) x = self.relu4(x) x = self.fc3(x) return x # 定义一个网络实例 net = Net() # 对网络进行剪枝操作 parameters_to_prune = ( (net.conv1, 'weight'), (net.conv2, 'weight'), (net.fc1, 'weight'), (net.fc2, 'weight'), (net.fc3, 'weight'), ) prune.global_unstructured( parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2, ) # 使用剪枝后的网络进行训练
对于上述网络中我们可以得到一个10分类的网络。
yolo 系列
在bubbliiiing博主的yolo系列的代码中,我们在train.py文件中找到创建yolo模型,在载入model后,对bubbliiiing博主的主干网络中的三个特征图进行压缩裁枝操作。下面是部分代码,可供参考。
ini
复制代码
#------------------------------------------------------# # 创建yolo模型 #------------------------------------------------------# model = YoloBody(anchors_mask, num_classes, phi, pretrained=pretrained) parameters_to_prune = ( (model.yolo_head_P3, 'weight'), (model.yolo_head_P3, 'weight'), (model.yolo_head_P3, 'weight'), ) prune.global_unstructured( parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2, )