YOLOv4剪枝【附代码】

简介: 笔记

对于单层剪枝主要分以下步骤[怕自己翻译有差距,所以各位可自行翻译理解]:

30.png

大致意思就说对于你要剪的滤波器(就是卷积了),计算每个通道权重绝对值之和,就是L1,将计算结果进行排序,然后用一个最小值(可以理解为一个阈值)和这些卷积核相关的特征层的m个通道进行剪枝,剪枝以后会生成一个新的卷积核,将原来卷积核内没有被剪的权重赋值给现在得到的新卷积核。

31.png

多层剪枝(两种策略):


• Independent pruning determines which filters should be pruned at each layer independent of other layers.


• Greedy pruning accounts for the filters that have been removed in the previous layers. This strategy does not consider the kernels for the previously pruned feature maps while calculating the sum of absolute weights.


1. 独立剪枝:对每一层的滤波器(卷积核)决定是否剪枝与其他层无关


2.贪婪剪枝:考虑先前层中所移除的滤波器。在剪枝计算权重绝对值之和时,不需要考虑先前已剪枝的特征图所对应的核


对于策略1独立剪枝,核的绿色部分是计算输出通道绝对值之和时,不考虑之前移除的特征层的这一维度(蓝色部分),也就是只考虑绿色这一列通道之和(这里的个人理解:蓝色这一行在单层剪枝是直接去掉的,也就是图中黄色部分都没有了,但在多层的独立剪枝中,在ni+2方向上只把蓝色区域去除,但黄色的部分要保留),所以核内黄色区域权重保留!


在策略2的贪婪剪枝中,不去统计在特征层中准备剪掉的通道[即考虑先前层中移除的]【个人理解,就是这里又不考虑黄色部分了吧】【可以自己翻译理解下:The greedy pruning strategy does not count kernels for the already pruned feature maps.】


32.png

环境:


显卡:英伟达1650


pytorch 1.7.0(低版本应该也是可以的)


torchvision 0.8.0


torch_pruning


安装


pip install torch_pruning


导入包


import torch_pruning as tp


模型的实例化(针对已经训练好的模型)


model = torch.load('权重路径') model.eval()


对于非单层卷积的通道剪枝(不用看3.1)


剪枝之前统计一下模型的参数量

num_params_before_pruning = tp.utils.count_params(model)

1. setup strategy (L1 Norm) 计算每个通道的权重


strategy = tp.strategy.L1Strategy()

2.建立依赖图(与torch.jit很像)


DG = tp.DependencyGraph() 
DG = DG.build_dependency(model, example_inputs=torch.randn(1, 3, input_size[0], input_size[1])) # input_size是网络的输入大小

3.分情况(1.单个卷积进行剪枝 2.层进行剪枝)


3.1 单个卷积(会返回要剪枝的通道索引,这个通道是索引是根据L1正则得到的)

pruning_idxs = strategy(model.conv1.weight, amount=0.4) # model.conv1.weigth是对特定的卷积进行剪枝,amount是剪枝率

将根据依赖图收集所有受影响的层,将它们传播到整个图上,然后提供一个PruningPlan正确修剪模型的方法。


pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs ) pruning_plan.exec() 
torch.save(model, 'pru_model.pth')

3.2 层剪枝(需要筛选出不需要剪枝的层,比如yolo需要把头部的预测部分取出来,这个是不需要剪枝的)

excluded_layers = list(model.model[-1].modules()) 
for m in model.modules(): 
    if isinstance(m, nn.Conv2d) and m not in excluded_layers: 
        pruning_plan = DG.get_pruning_plan(m,tp.prune_conv, idxs=strategy(m.weight, amount=0.4)) 
        print(pruning_plan) # 执行剪枝 pruning_plan.exec()

如果想看一下剪枝以后的参数,可以运行:


num_params_after_pruning = tp.utils.count_params(model) 
print( " Params: %s => %s"%( num_params_before_pruning, num_params_after_pruning))

剪枝完以后模型的保存(不要用torch.save(model.state_dict(),...))


torch.save(model, 'pruning_model.pth')

在代码中的prunmodel.py是对yolov4剪枝代码


如果你的权重是用torch.save(model.state_dict())保存的,请重新加载模型后用torch.save(model)保存[或调用save_whole_model函数]


如果你需要对单独的一个卷积进行剪枝,可以调用Conv_pruning(模型权重路径[包含网络结构图]),然后在k处修改你要剪枝的某一个卷积


如果需要对某部分进行剪枝,可以调用layer_pruning()函数,included_layers是你想要剪枝的部分[我这里是对SPP后面的三个卷积剪枝,如果需要剪枝别的地方,需要修改list里面的参数,注意尽量不要对head部分剪枝]


预测部分


1.png

我尝试了一下对主干剪枝,发现精度损失严重,大家想剪哪部分可以自己去尝试,我只是把框架给搭建起来方便大家的使用,对最终的效果不保证,需要自己去炼丹。

可以训练自己的模型,剪枝后应该对模型进行一个重训练的微调提升准确率,这部分代码我还没有加入进去,可以自己把剪枝后的权重放训练代码中微调一下就行。后期有时间会加入微调训练部分。


image.png

剪枝后的预测结果


在coco数据集上,剪枝前后的参数量输出如下:

Params: 64363101 => 60438323

正确的剪枝后会打印出以下信息【如果有类似的信息就是可以正常剪枝,如果没出现,说明你剪的不对】:

[ <DEP: prune_conv => prune_conv on conv2.2.conv (Conv2d(615, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=125460]
[ <DEP: prune_conv => prune_batchnorm on conv2.2.bn (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=408]
[ <DEP: prune_batchnorm => _prune_elementwise_op on _ElementWiseOp()>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on upsample1.upsample.0.conv (Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=52224]
[ <DEP: _prune_elementwise_op => _prune_concat on _ConcatOp([0, 512, 1024])>, Index=[513, 514, 515, 516, 519, 521, 527, 528, 529, 530, 535, 536, 547, 551, 554, 557, 558, 565, 566, 570, 574, 575, 577, 582, 585, 587, 593, 596, 600, 602, 609, 613, 614, 618, 621, 623, 624, 625, 628, 631, 636, 639, 644, 645, 647, 650, 652, 653, 655, 657, 660, 661, 662, 669, 671, 673, 675, 678, 680, 681, 682, 684, 686, 688, 689, 691, 692, 693, 694, 697, 698, 699, 701, 703, 705, 708, 712, 714, 719, 722, 723, 728, 729, 732, 734, 735, 737, 738, 740, 745, 746, 748, 750, 754, 755, 757, 758, 759, 760, 761, 763, 765, 766, 767, 768, 769, 778, 779, 782, 783, 791, 792, 796, 799, 800, 803, 807, 811, 813, 814, 817, 818, 819, 821, 823, 826, 828, 831, 837, 838, 840, 843, 845, 853, 854, 856, 857, 858, 860, 861, 864, 867, 868, 869, 873, 883, 884, 886, 887, 891, 893, 897, 900, 903, 906, 907, 911, 913, 915, 918, 920, 922, 926, 927, 930, 933, 937, 942, 944, 945, 946, 950, 951, 952, 956, 957, 960, 961, 964, 965, 973, 977, 979, 980, 981, 985, 987, 991, 992, 993, 994, 997, 999, 1004, 1005, 1007, 1009, 1011, 1016, 1017, 1019, 1020, 1022, 1023], NumPruned=0]
[ <DEP: _prune_concat => prune_related_conv on make_five_conv4.0.conv (Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[513, 514, 515, 516, 519, 521, 527, 528, 529, 530, 535, 536, 547, 551, 554, 557, 558, 565, 566, 570, 574, 575, 577, 582, 585, 587, 593, 596, 600, 602, 609, 613, 614, 618, 621, 623, 624, 625, 628, 631, 636, 639, 644, 645, 647, 650, 652, 653, 655, 657, 660, 661, 662, 669, 671, 673, 675, 678, 680, 681, 682, 684, 686, 688, 689, 691, 692, 693, 694, 697, 698, 699, 701, 703, 705, 708, 712, 714, 719, 722, 723, 728, 729, 732, 734, 735, 737, 738, 740, 745, 746, 748, 750, 754, 755, 757, 758, 759, 760, 761, 763, 765, 766, 767, 768, 769, 778, 779, 782, 783, 791, 792, 796, 799, 800, 803, 807, 811, 813, 814, 817, 818, 819, 821, 823, 826, 828, 831, 837, 838, 840, 843, 845, 853, 854, 856, 857, 858, 860, 861, 864, 867, 868, 869, 873, 883, 884, 886, 887, 891, 893, 897, 900, 903, 906, 907, 911, 913, 915, 918, 920, 922, 926, 927, 930, 933, 937, 942, 944, 945, 946, 950, 951, 952, 956, 957, 960, 961, 964, 965, 973, 977, 979, 980, 981, 985, 987, 991, 992, 993, 994, 997, 999, 1004, 1005, 1007, 1009, 1011, 1016, 1017, 1019, 1020, 1022, 1023], NumPruned=104448]
282540 parameters will be pruned
目录
相关文章
|
6月前
|
弹性计算 人工智能 自然语言处理
产品动态丨阿里云计算巢月刊-2025年第5期
让优秀的企业软件生于云、长于云
|
SQL 自然语言处理 安全
2024 年 8 月暨 ACL 2024 57篇代码大模型论文精选
2024年8月中旬,国际计算语言学大会ACL在泰国曼谷举行,展示了48篇代码大模型相关论文,包括24篇主会论文和24篇findings论文。主会论文涵盖XFT、WaveCoder、DolphCoder等创新方法,findings论文则探讨了代码注释增强、自动化程序修复等主题。此外,还额外整理了9篇8月最新代码大模型论文,涉及数据集合成、安全代码生成等多个前沿方向。欲了解更多,请访问我们的综述和GitHub项目。
1394 4
|
算法 数据挖掘
Faiss为啥这么快?原来是量化器在做怪!2
Faiss为啥这么快?原来是量化器在做怪!
860 1
|
移动开发 前端开发 JavaScript
UniApp H5项目大揭秘:高效生成与扫描二维码的终极策略,让你的应用脱颖而出!
【8月更文挑战第3天】UniApp让开发者能以Vue.js构建跨平台应用。在H5项目中,通过第三方库如qrcodejs2可轻松生成二维码,代码简洁易集成;或用Canvas API获得更高灵活性。扫描方面,H5+ API适合App环境,而纯H5项目则需前端库加后端服务配合。不同方法各有优势,应按需选择以优化体验。
826 0
|
存储 缓存 C++
GGML 非官方中文文档(1)
GGML 非官方中文文档
657 1
|
机器学习/深度学习 存储 算法
Faiss为啥这么快?原来是量化器在做怪!1
Faiss为啥这么快?原来是量化器在做怪!
1155 0
STM32:GPIO控制LED闪烁代码部分(内含配置图+代码+代码注释)
STM32:GPIO控制LED闪烁代码部分(内含配置图+代码+代码注释)
997 0
STM32:GPIO控制LED闪烁代码部分(内含配置图+代码+代码注释)
|
存储 异构计算 索引
GGML 非官方中文文档(3)
GGML 非官方中文文档
463 0
|
程序员 Python
利用Python实现科学式占卜
一直以来,中式占卜都是基于算命先生手工实现,程序繁琐(往往需要沐浴、计算天时、静心等等流程)。准备工作复杂(通常需要铜钱等道具),计算方法复杂,需要纯手工计算二进制并转换为最终的卦象,为了解决这个问题,笔者基于python实现了一套科学算命工具,用于快速进行占卜。 本文的算命方式采用八卦 + 周易+ 梅花易数实现,脚本基于python3.9.0开发。本人对于周易五行研究较浅,如有疏漏请见谅。 最终效果如图,在运行程序之后,会根据当前的运势自动获取你心中所想之事的卦象(本卦、互卦、变卦) 前置知识 基础原理 首先我们需要了解一些最基本的占卜知识,目前我国几种比较主流的占卜方式基本都是基
349 0
|
存储 Kubernetes Docker
Docker容器编排技术解析
Docker容器编排技术解析
790 0