YOLOv4剪枝【附代码】

简介: 笔记

对于单层剪枝主要分以下步骤[怕自己翻译有差距,所以各位可自行翻译理解]:

30.png

大致意思就说对于你要剪的滤波器(就是卷积了),计算每个通道权重绝对值之和,就是L1,将计算结果进行排序,然后用一个最小值(可以理解为一个阈值)和这些卷积核相关的特征层的m个通道进行剪枝,剪枝以后会生成一个新的卷积核,将原来卷积核内没有被剪的权重赋值给现在得到的新卷积核。

31.png

多层剪枝(两种策略):


• Independent pruning determines which filters should be pruned at each layer independent of other layers.


• Greedy pruning accounts for the filters that have been removed in the previous layers. This strategy does not consider the kernels for the previously pruned feature maps while calculating the sum of absolute weights.


1. 独立剪枝:对每一层的滤波器(卷积核)决定是否剪枝与其他层无关


2.贪婪剪枝:考虑先前层中所移除的滤波器。在剪枝计算权重绝对值之和时,不需要考虑先前已剪枝的特征图所对应的核


对于策略1独立剪枝,核的绿色部分是计算输出通道绝对值之和时,不考虑之前移除的特征层的这一维度(蓝色部分),也就是只考虑绿色这一列通道之和(这里的个人理解:蓝色这一行在单层剪枝是直接去掉的,也就是图中黄色部分都没有了,但在多层的独立剪枝中,在ni+2方向上只把蓝色区域去除,但黄色的部分要保留),所以核内黄色区域权重保留!


在策略2的贪婪剪枝中,不去统计在特征层中准备剪掉的通道[即考虑先前层中移除的]【个人理解,就是这里又不考虑黄色部分了吧】【可以自己翻译理解下:The greedy pruning strategy does not count kernels for the already pruned feature maps.】


32.png

环境:


显卡:英伟达1650


pytorch 1.7.0(低版本应该也是可以的)


torchvision 0.8.0


torch_pruning


安装


pip install torch_pruning


导入包


import torch_pruning as tp


模型的实例化(针对已经训练好的模型)


model = torch.load('权重路径') model.eval()


对于非单层卷积的通道剪枝(不用看3.1)


剪枝之前统计一下模型的参数量

num_params_before_pruning = tp.utils.count_params(model)

1. setup strategy (L1 Norm) 计算每个通道的权重


strategy = tp.strategy.L1Strategy()

2.建立依赖图(与torch.jit很像)


DG = tp.DependencyGraph() 
DG = DG.build_dependency(model, example_inputs=torch.randn(1, 3, input_size[0], input_size[1])) # input_size是网络的输入大小

3.分情况(1.单个卷积进行剪枝 2.层进行剪枝)


3.1 单个卷积(会返回要剪枝的通道索引,这个通道是索引是根据L1正则得到的)

pruning_idxs = strategy(model.conv1.weight, amount=0.4) # model.conv1.weigth是对特定的卷积进行剪枝,amount是剪枝率

将根据依赖图收集所有受影响的层,将它们传播到整个图上,然后提供一个PruningPlan正确修剪模型的方法。


pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs ) pruning_plan.exec() 
torch.save(model, 'pru_model.pth')

3.2 层剪枝(需要筛选出不需要剪枝的层,比如yolo需要把头部的预测部分取出来,这个是不需要剪枝的)

excluded_layers = list(model.model[-1].modules()) 
for m in model.modules(): 
    if isinstance(m, nn.Conv2d) and m not in excluded_layers: 
        pruning_plan = DG.get_pruning_plan(m,tp.prune_conv, idxs=strategy(m.weight, amount=0.4)) 
        print(pruning_plan) # 执行剪枝 pruning_plan.exec()

如果想看一下剪枝以后的参数,可以运行:


num_params_after_pruning = tp.utils.count_params(model) 
print( " Params: %s => %s"%( num_params_before_pruning, num_params_after_pruning))

剪枝完以后模型的保存(不要用torch.save(model.state_dict(),...))


torch.save(model, 'pruning_model.pth')

在代码中的prunmodel.py是对yolov4剪枝代码


如果你的权重是用torch.save(model.state_dict())保存的,请重新加载模型后用torch.save(model)保存[或调用save_whole_model函数]


如果你需要对单独的一个卷积进行剪枝,可以调用Conv_pruning(模型权重路径[包含网络结构图]),然后在k处修改你要剪枝的某一个卷积


如果需要对某部分进行剪枝,可以调用layer_pruning()函数,included_layers是你想要剪枝的部分[我这里是对SPP后面的三个卷积剪枝,如果需要剪枝别的地方,需要修改list里面的参数,注意尽量不要对head部分剪枝]


预测部分


1.png

我尝试了一下对主干剪枝,发现精度损失严重,大家想剪哪部分可以自己去尝试,我只是把框架给搭建起来方便大家的使用,对最终的效果不保证,需要自己去炼丹。

可以训练自己的模型,剪枝后应该对模型进行一个重训练的微调提升准确率,这部分代码我还没有加入进去,可以自己把剪枝后的权重放训练代码中微调一下就行。后期有时间会加入微调训练部分。


image.png

剪枝后的预测结果


在coco数据集上,剪枝前后的参数量输出如下:

Params: 64363101 => 60438323

正确的剪枝后会打印出以下信息【如果有类似的信息就是可以正常剪枝,如果没出现,说明你剪的不对】:

[ <DEP: prune_conv => prune_conv on conv2.2.conv (Conv2d(615, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=125460]
[ <DEP: prune_conv => prune_batchnorm on conv2.2.bn (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=408]
[ <DEP: prune_batchnorm => _prune_elementwise_op on _ElementWiseOp()>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on upsample1.upsample.0.conv (Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=52224]
[ <DEP: _prune_elementwise_op => _prune_concat on _ConcatOp([0, 512, 1024])>, Index=[513, 514, 515, 516, 519, 521, 527, 528, 529, 530, 535, 536, 547, 551, 554, 557, 558, 565, 566, 570, 574, 575, 577, 582, 585, 587, 593, 596, 600, 602, 609, 613, 614, 618, 621, 623, 624, 625, 628, 631, 636, 639, 644, 645, 647, 650, 652, 653, 655, 657, 660, 661, 662, 669, 671, 673, 675, 678, 680, 681, 682, 684, 686, 688, 689, 691, 692, 693, 694, 697, 698, 699, 701, 703, 705, 708, 712, 714, 719, 722, 723, 728, 729, 732, 734, 735, 737, 738, 740, 745, 746, 748, 750, 754, 755, 757, 758, 759, 760, 761, 763, 765, 766, 767, 768, 769, 778, 779, 782, 783, 791, 792, 796, 799, 800, 803, 807, 811, 813, 814, 817, 818, 819, 821, 823, 826, 828, 831, 837, 838, 840, 843, 845, 853, 854, 856, 857, 858, 860, 861, 864, 867, 868, 869, 873, 883, 884, 886, 887, 891, 893, 897, 900, 903, 906, 907, 911, 913, 915, 918, 920, 922, 926, 927, 930, 933, 937, 942, 944, 945, 946, 950, 951, 952, 956, 957, 960, 961, 964, 965, 973, 977, 979, 980, 981, 985, 987, 991, 992, 993, 994, 997, 999, 1004, 1005, 1007, 1009, 1011, 1016, 1017, 1019, 1020, 1022, 1023], NumPruned=0]
[ <DEP: _prune_concat => prune_related_conv on make_five_conv4.0.conv (Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[513, 514, 515, 516, 519, 521, 527, 528, 529, 530, 535, 536, 547, 551, 554, 557, 558, 565, 566, 570, 574, 575, 577, 582, 585, 587, 593, 596, 600, 602, 609, 613, 614, 618, 621, 623, 624, 625, 628, 631, 636, 639, 644, 645, 647, 650, 652, 653, 655, 657, 660, 661, 662, 669, 671, 673, 675, 678, 680, 681, 682, 684, 686, 688, 689, 691, 692, 693, 694, 697, 698, 699, 701, 703, 705, 708, 712, 714, 719, 722, 723, 728, 729, 732, 734, 735, 737, 738, 740, 745, 746, 748, 750, 754, 755, 757, 758, 759, 760, 761, 763, 765, 766, 767, 768, 769, 778, 779, 782, 783, 791, 792, 796, 799, 800, 803, 807, 811, 813, 814, 817, 818, 819, 821, 823, 826, 828, 831, 837, 838, 840, 843, 845, 853, 854, 856, 857, 858, 860, 861, 864, 867, 868, 869, 873, 883, 884, 886, 887, 891, 893, 897, 900, 903, 906, 907, 911, 913, 915, 918, 920, 922, 926, 927, 930, 933, 937, 942, 944, 945, 946, 950, 951, 952, 956, 957, 960, 961, 964, 965, 973, 977, 979, 980, 981, 985, 987, 991, 992, 993, 994, 997, 999, 1004, 1005, 1007, 1009, 1011, 1016, 1017, 1019, 1020, 1022, 1023], NumPruned=104448]
282540 parameters will be pruned
目录
打赏
0
0
0
0
78
分享
相关文章
RocketMQ原理—4.消息读写的性能优化
本文详细解析了RocketMQ消息队列的核心原理与性能优化机制,涵盖Producer消息分发、Broker高并发写入、Consumer拉取消息流程等内容。重点探讨了基于队列的消息分发、Hash有序分发、CommitLog内存写入优化、ConsumeQueue物理存储设计等关键技术点。同时分析了数据丢失场景及解决方案,如同步刷盘与JVM OffHeap缓存分离策略,并总结了写入与读取流程的性能优化方法,为理解和优化分布式消息系统提供了全面指导。
RocketMQ原理—4.消息读写的性能优化
YOLO落地部署 | 一文全览YOLOv5最新的剪枝、量化的进展【必读】
YOLO落地部署 | 一文全览YOLOv5最新的剪枝、量化的进展【必读】
1728 0
2024 年 8 月暨 ACL 2024 57篇代码大模型论文精选
2024年8月中旬,国际计算语言学大会ACL在泰国曼谷举行,展示了48篇代码大模型相关论文,包括24篇主会论文和24篇findings论文。主会论文涵盖XFT、WaveCoder、DolphCoder等创新方法,findings论文则探讨了代码注释增强、自动化程序修复等主题。此外,还额外整理了9篇8月最新代码大模型论文,涉及数据集合成、安全代码生成等多个前沿方向。欲了解更多,请访问我们的综述和GitHub项目。
1031 4
【YOLOv8改进 - 注意力机制】NAM:基于归一化的注意力模块,将权重稀疏惩罚应用于注意力机制中,提高效率性能
**NAM: 提升模型效率的新颖归一化注意力模块,抑制非显著权重,结合通道和空间注意力,通过批量归一化衡量重要性。在Resnet和Mobilenet上的实验显示优于其他三种机制。源码见[GitHub](https://github.com/Christian-lyc/NAM)。**
使用Python进行曲线拟合:利用贝塞尔曲线
使用Python进行曲线拟合:利用贝塞尔曲线
569 1
YOLOv8模型yaml结构图理解(逐层分析)
YOLOv8模型yaml结构图理解(逐层分析)
15280 0
【Python 机器学习专栏】数据缺失值处理与插补方法
【4月更文挑战第30天】本文探讨了Python中处理数据缺失值的方法。缺失值影响数据分析和模型训练,可能导致模型偏差、准确性降低和干扰分析。检测缺失值可使用Pandas的`isnull()`和`notnull()`,或通过可视化。处理方法包括删除含缺失值的行/列及填充:固定值、均值/中位数、众数或最近邻。Scikit-learn提供了SimpleImputer和IterativeImputer类进行插补。选择方法要考虑数据特点、缺失值比例和模型需求。注意过度插补和验证评估。处理缺失值是提升数据质量和模型准确性关键步骤。
963 0
计算机网络:可靠数据传输(rdt)、流水协议、窗口滑动协议
计算机网络:可靠数据传输(rdt)、流水协议、窗口滑动协议
1122 2
Linux下采集摄像头的图像再保存为JPG图片存放到本地(YUYV转JPG)
Linux下采集摄像头的图像再保存为JPG图片存放到本地(YUYV转JPG)
2193 1
Linux下采集摄像头的图像再保存为JPG图片存放到本地(YUYV转JPG)
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等

登录插画

登录以查看您的控制台资源

管理云资源
状态一览
快捷访问