YOLOv4剪枝【附代码】

简介: 笔记

对于单层剪枝主要分以下步骤[怕自己翻译有差距,所以各位可自行翻译理解]:

30.png

大致意思就说对于你要剪的滤波器(就是卷积了),计算每个通道权重绝对值之和,就是L1,将计算结果进行排序,然后用一个最小值(可以理解为一个阈值)和这些卷积核相关的特征层的m个通道进行剪枝,剪枝以后会生成一个新的卷积核,将原来卷积核内没有被剪的权重赋值给现在得到的新卷积核。

31.png

多层剪枝(两种策略):


• Independent pruning determines which filters should be pruned at each layer independent of other layers.


• Greedy pruning accounts for the filters that have been removed in the previous layers. This strategy does not consider the kernels for the previously pruned feature maps while calculating the sum of absolute weights.


1. 独立剪枝:对每一层的滤波器(卷积核)决定是否剪枝与其他层无关


2.贪婪剪枝:考虑先前层中所移除的滤波器。在剪枝计算权重绝对值之和时,不需要考虑先前已剪枝的特征图所对应的核


对于策略1独立剪枝,核的绿色部分是计算输出通道绝对值之和时,不考虑之前移除的特征层的这一维度(蓝色部分),也就是只考虑绿色这一列通道之和(这里的个人理解:蓝色这一行在单层剪枝是直接去掉的,也就是图中黄色部分都没有了,但在多层的独立剪枝中,在ni+2方向上只把蓝色区域去除,但黄色的部分要保留),所以核内黄色区域权重保留!


在策略2的贪婪剪枝中,不去统计在特征层中准备剪掉的通道[即考虑先前层中移除的]【个人理解,就是这里又不考虑黄色部分了吧】【可以自己翻译理解下:The greedy pruning strategy does not count kernels for the already pruned feature maps.】


32.png

环境:


显卡:英伟达1650


pytorch 1.7.0(低版本应该也是可以的)


torchvision 0.8.0


torch_pruning


安装


pip install torch_pruning


导入包


import torch_pruning as tp


模型的实例化(针对已经训练好的模型)


model = torch.load('权重路径') model.eval()


对于非单层卷积的通道剪枝(不用看3.1)


剪枝之前统计一下模型的参数量

num_params_before_pruning = tp.utils.count_params(model)

1. setup strategy (L1 Norm) 计算每个通道的权重


strategy = tp.strategy.L1Strategy()

2.建立依赖图(与torch.jit很像)


DG = tp.DependencyGraph() 
DG = DG.build_dependency(model, example_inputs=torch.randn(1, 3, input_size[0], input_size[1])) # input_size是网络的输入大小

3.分情况(1.单个卷积进行剪枝 2.层进行剪枝)


3.1 单个卷积(会返回要剪枝的通道索引,这个通道是索引是根据L1正则得到的)

pruning_idxs = strategy(model.conv1.weight, amount=0.4) # model.conv1.weigth是对特定的卷积进行剪枝,amount是剪枝率

将根据依赖图收集所有受影响的层,将它们传播到整个图上,然后提供一个PruningPlan正确修剪模型的方法。


pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs ) pruning_plan.exec() 
torch.save(model, 'pru_model.pth')

3.2 层剪枝(需要筛选出不需要剪枝的层,比如yolo需要把头部的预测部分取出来,这个是不需要剪枝的)

excluded_layers = list(model.model[-1].modules()) 
for m in model.modules(): 
    if isinstance(m, nn.Conv2d) and m not in excluded_layers: 
        pruning_plan = DG.get_pruning_plan(m,tp.prune_conv, idxs=strategy(m.weight, amount=0.4)) 
        print(pruning_plan) # 执行剪枝 pruning_plan.exec()

如果想看一下剪枝以后的参数,可以运行:


num_params_after_pruning = tp.utils.count_params(model) 
print( " Params: %s => %s"%( num_params_before_pruning, num_params_after_pruning))

剪枝完以后模型的保存(不要用torch.save(model.state_dict(),...))


torch.save(model, 'pruning_model.pth')

在代码中的prunmodel.py是对yolov4剪枝代码


如果你的权重是用torch.save(model.state_dict())保存的,请重新加载模型后用torch.save(model)保存[或调用save_whole_model函数]


如果你需要对单独的一个卷积进行剪枝,可以调用Conv_pruning(模型权重路径[包含网络结构图]),然后在k处修改你要剪枝的某一个卷积


如果需要对某部分进行剪枝,可以调用layer_pruning()函数,included_layers是你想要剪枝的部分[我这里是对SPP后面的三个卷积剪枝,如果需要剪枝别的地方,需要修改list里面的参数,注意尽量不要对head部分剪枝]


预测部分


1.png

我尝试了一下对主干剪枝,发现精度损失严重,大家想剪哪部分可以自己去尝试,我只是把框架给搭建起来方便大家的使用,对最终的效果不保证,需要自己去炼丹。

可以训练自己的模型,剪枝后应该对模型进行一个重训练的微调提升准确率,这部分代码我还没有加入进去,可以自己把剪枝后的权重放训练代码中微调一下就行。后期有时间会加入微调训练部分。


image.png

剪枝后的预测结果


在coco数据集上,剪枝前后的参数量输出如下:

Params: 64363101 => 60438323

正确的剪枝后会打印出以下信息【如果有类似的信息就是可以正常剪枝,如果没出现,说明你剪的不对】:

[ <DEP: prune_conv => prune_conv on conv2.2.conv (Conv2d(615, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=125460]
[ <DEP: prune_conv => prune_batchnorm on conv2.2.bn (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=408]
[ <DEP: prune_batchnorm => _prune_elementwise_op on _ElementWiseOp()>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on upsample1.upsample.0.conv (Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[1, 2, 3, 4, 7, 9, 15, 16, 17, 18, 23, 24, 35, 39, 42, 45, 46, 53, 54, 58, 62, 63, 65, 70, 73, 75, 81, 84, 88, 90, 97, 101, 102, 106, 109, 111, 112, 113, 116, 119, 124, 127, 132, 133, 135, 138, 140, 141, 143, 145, 148, 149, 150, 157, 159, 161, 163, 166, 168, 169, 170, 172, 174, 176, 177, 179, 180, 181, 182, 185, 186, 187, 189, 191, 193, 196, 200, 202, 207, 210, 211, 216, 217, 220, 222, 223, 225, 226, 228, 233, 234, 236, 238, 242, 243, 245, 246, 247, 248, 249, 251, 253, 254, 255, 256, 257, 266, 267, 270, 271, 279, 280, 284, 287, 288, 291, 295, 299, 301, 302, 305, 306, 307, 309, 311, 314, 316, 319, 325, 326, 328, 331, 333, 341, 342, 344, 345, 346, 348, 349, 352, 355, 356, 357, 361, 371, 372, 374, 375, 379, 381, 385, 388, 391, 394, 395, 399, 401, 403, 406, 408, 410, 414, 415, 418, 421, 425, 430, 432, 433, 434, 438, 439, 440, 444, 445, 448, 449, 452, 453, 461, 465, 467, 468, 469, 473, 475, 479, 480, 481, 482, 485, 487, 492, 493, 495, 497, 499, 504, 505, 507, 508, 510, 511], NumPruned=52224]
[ <DEP: _prune_elementwise_op => _prune_concat on _ConcatOp([0, 512, 1024])>, Index=[513, 514, 515, 516, 519, 521, 527, 528, 529, 530, 535, 536, 547, 551, 554, 557, 558, 565, 566, 570, 574, 575, 577, 582, 585, 587, 593, 596, 600, 602, 609, 613, 614, 618, 621, 623, 624, 625, 628, 631, 636, 639, 644, 645, 647, 650, 652, 653, 655, 657, 660, 661, 662, 669, 671, 673, 675, 678, 680, 681, 682, 684, 686, 688, 689, 691, 692, 693, 694, 697, 698, 699, 701, 703, 705, 708, 712, 714, 719, 722, 723, 728, 729, 732, 734, 735, 737, 738, 740, 745, 746, 748, 750, 754, 755, 757, 758, 759, 760, 761, 763, 765, 766, 767, 768, 769, 778, 779, 782, 783, 791, 792, 796, 799, 800, 803, 807, 811, 813, 814, 817, 818, 819, 821, 823, 826, 828, 831, 837, 838, 840, 843, 845, 853, 854, 856, 857, 858, 860, 861, 864, 867, 868, 869, 873, 883, 884, 886, 887, 891, 893, 897, 900, 903, 906, 907, 911, 913, 915, 918, 920, 922, 926, 927, 930, 933, 937, 942, 944, 945, 946, 950, 951, 952, 956, 957, 960, 961, 964, 965, 973, 977, 979, 980, 981, 985, 987, 991, 992, 993, 994, 997, 999, 1004, 1005, 1007, 1009, 1011, 1016, 1017, 1019, 1020, 1022, 1023], NumPruned=0]
[ <DEP: _prune_concat => prune_related_conv on make_five_conv4.0.conv (Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[513, 514, 515, 516, 519, 521, 527, 528, 529, 530, 535, 536, 547, 551, 554, 557, 558, 565, 566, 570, 574, 575, 577, 582, 585, 587, 593, 596, 600, 602, 609, 613, 614, 618, 621, 623, 624, 625, 628, 631, 636, 639, 644, 645, 647, 650, 652, 653, 655, 657, 660, 661, 662, 669, 671, 673, 675, 678, 680, 681, 682, 684, 686, 688, 689, 691, 692, 693, 694, 697, 698, 699, 701, 703, 705, 708, 712, 714, 719, 722, 723, 728, 729, 732, 734, 735, 737, 738, 740, 745, 746, 748, 750, 754, 755, 757, 758, 759, 760, 761, 763, 765, 766, 767, 768, 769, 778, 779, 782, 783, 791, 792, 796, 799, 800, 803, 807, 811, 813, 814, 817, 818, 819, 821, 823, 826, 828, 831, 837, 838, 840, 843, 845, 853, 854, 856, 857, 858, 860, 861, 864, 867, 868, 869, 873, 883, 884, 886, 887, 891, 893, 897, 900, 903, 906, 907, 911, 913, 915, 918, 920, 922, 926, 927, 930, 933, 937, 942, 944, 945, 946, 950, 951, 952, 956, 957, 960, 961, 964, 965, 973, 977, 979, 980, 981, 985, 987, 991, 992, 993, 994, 997, 999, 1004, 1005, 1007, 1009, 1011, 1016, 1017, 1019, 1020, 1022, 1023], NumPruned=104448]
282540 parameters will be pruned
目录
相关文章
|
6月前
|
机器学习/深度学习 算法 计算机视觉
YOLOv3的算法原理是怎么样的
YOLOv3的算法原理是怎么样的
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
pytorch中的模型剪枝
pytorch中的模型剪枝
|
机器学习/深度学习 存储 缓存
YOLOv5的Tricks | 【Trick9】模型剪枝处理与Pytorch实现的剪枝策略
在yolov5项目中的torch_utils.py文件下,有prune这个函数,用来实现模型的剪枝处理。对模型裁剪,模型剪枝这方面之前没有接触到,这里用这篇笔记来学习记录一下这方面内容。
2227 0
YOLOv5的Tricks | 【Trick9】模型剪枝处理与Pytorch实现的剪枝策略
|
机器学习/深度学习 编解码 自然语言处理
基于EasyCV复现ViTDet:单层特征超越FPN
ViTDet其实是恺明团队MAE和ViT-based Mask R-CNN两个工作的延续。MAE提出了ViT的无监督训练方法,而ViT-based Mask R-CNN给出了用ViT作为backbone的Mask R-CNN的训练技巧,并证明了MAE预训练对下游检测任务的重要性。而ViTDet进一步改进了一些设计,证明了ViT作为backone的检测模型可以匹敌基于FPN的backbone(如SwinT和MViT)检测模型。
|
机器学习/深度学习 Dart 并行计算
暴力涨点 | IC-Conv使用高效空洞搜索Inception卷积带来全领域涨点(文末附论文下载)(二)
暴力涨点 | IC-Conv使用高效空洞搜索Inception卷积带来全领域涨点(文末附论文下载)(二)
129 0
|
机器学习/深度学习 Dart 算法
暴力涨点 | IC-Conv使用高效空洞搜索Inception卷积带来全领域涨点(文末附论文下载)(一)
暴力涨点 | IC-Conv使用高效空洞搜索Inception卷积带来全领域涨点(文末附论文下载)(一)
184 0
|
机器学习/深度学习 人工智能 算法
【Pytorch神经网络理论篇】 23 对抗神经网络:概述流程 + WGAN模型 + WGAN-gp模型 + 条件GAN + WGAN-div + W散度
GAN的原理与条件变分自编码神经网络的原理一样。这种做法可以理解为给GAN增加一个条件,让网络学习图片分布时加入标签因素,这样可以按照标签的数值来生成指定的图片。
628 0
|
数据可视化 PyTorch 算法框架/工具
|
算法 Linux 异构计算
|
XML 数据格式 索引