pytorch利用hook【钩子】获取torch网络每层结构【附代码】

简介: 笔记

写本文的目的是为了方便在剪枝中或其他应用中获取网络结构,如何有效的利用hook获取每层的结构来判断是否可以剪枝。


要对网络进行trace,或者获取网络结构,需要知道“grad_fn”。我们知道在pytorch中导数对应的关键词为“grad”。对一个变量我们可以设置requires_grad为True或者False来设置该变量是否求偏导。


grad_fn


grad_fn: grad_fn用来记录变量变化的过程,方便计算梯度,比如:y = x*2,grad_fn记录了y由x计算的过程。

这里举个例子:设置一个x,并设置其可求导,就也是后面要对他求偏导。

x = torch.ones(2,2, requires_grad=True)
x
Out[4]: 
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)

当我们输出x的grad_fn时输出为None,这是因为这里的tensor x是直接给出的,他并没有经过任何的运算。


print(x.grad_fn)
None

当我们设置一个简单的二次函数y=2 * x,可以得到如下结果,可以看到grad_fn现在显示的MulBackward0,意思就是用了乘法。


y = 2*x
Out[10]: 
tensor([[2., 2.],
        [2., 2.]], grad_fn=<MulBackward0>)

hook获取网络结构


通过理解了grad_fn,那么我们就可以对网络进行trace,获取每层的网络结构了。


这里以YOLOv5s为例。先附上代码,PRUNABLE_MODULES列表放了Conv,BN以及PReLu。


grad_fn_to_module字典是用来通过grad_fn获取网络每层的结构,也就是如果grad_fn不为None的时候就放入字典中。


visited用来记录每层出现的次数。


这里会用到一个关键的函数:register_forward_hook。


该函数的作用是在不改变torch网络的情况下获取每层的输出。该方法需要传入一个func,其中包含module,inputs,outputs。也就是我下面代码中定义的_record_module_grad_fn。


import torch
import torch.nn as nn
PRUNABLE_MODULES = [ nn.modules.conv._ConvNd, nn.modules.batchnorm._BatchNorm, nn.Linear, nn.PReLU]
grad_fn_to_module = {}  # 如果获取不到是无法剪枝的
visited = {}  # visited会记录每层出现的次数
def _record_module_grad_fn(module, inputs, outputs): # 记录model的grad_fn
    if module not in visited:
        visited[module] = 1
    else:
        visited[module] += 1
    grad_fn_to_module[outputs.grad_fn] = module
model = torch.load('../runs/train/exp/weights/best.pt')['model'].float().cpu()
for para in model.parameters():
    para.requires_grad = True
x = torch.ones(1, 3, 640, 640)
for m in model.modules():
    if isinstance(m, tuple(PRUNABLE_MODULES)):
        hooks = [m.register_forward_hook(_record_module_grad_fn)]
out = model(x)
for hook in hooks:
    hook.remove()
print(grad_fn_to_module)


这里需要注意:在代码运行到out = model(x)之前的过程中,grad_fn_to_module字典一直为空。通过debug也可以看到。

45.png


但是!!当我用样例x将我的mode跑了一遍获得out的时候,此刻grad_fn_to_module就开始将网络从头到尾开始记录了。该字典内容如下,可以看到针对第一个key为Convolution操作,所以记录下了Conv2d(3,32,.....)这一层,后面都是如此。


{: Conv2d(3, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2), bias=False), : BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), : BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False),



46.png

通过上面的方法我们就可以通过hook获取网络的每层结构。这个就可以用来做剪枝操作。


注意!在我的代码中可以看到有这么一行:


for para in model.parameters():
    para.requires_grad = True

我这里将模型的所有参数均设置为可导的,为什么要这里设置呢,这是因为我在对官方代码yolov5 6.0代码剪枝的时候,发现backbone无法剪枝,比如我想对第一层进行剪枝,会给我报KeyError的错误,最后通过仔细研究发现,在官方提供的v5模型中backbone的grad_fn均为None,利用hook无法获得网络,只能获得head部分的结构,下面显示是backbone的grad_fn为None记录的结构,:解决的办法也很简单,就是加入我上面的代码,并设置参数可导即可。


{None: BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False), :  


你学会了吗~

目录
相关文章
|
3月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
|
3月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
152 0
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络
|
7月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。
|
7月前
|
机器学习/深度学习 自然语言处理 PyTorch
基于Pytorch Gemotric在昇腾上实现GAT图神经网络
本实验基于昇腾平台,使用PyTorch实现图神经网络GAT(Graph Attention Networks)在Pubmed数据集上的分类任务。内容涵盖GAT网络的创新点分析、图注意力机制原理、多头注意力机制详解以及模型代码实战。实验通过两层GAT网络对Pubmed数据集进行训练,验证模型性能,并展示NPU上的内存使用情况。最终,模型在测试集上达到约36.60%的准确率。
|
7月前
|
算法 PyTorch 算法框架/工具
PyTorch 实现FCN网络用于图像语义分割
本文详细讲解了在昇腾平台上使用PyTorch实现FCN(Fully Convolutional Networks)网络在VOC2012数据集上的训练过程。内容涵盖FCN的创新点分析、网络架构解析、代码实现以及端到端训练流程。重点包括全卷积结构替换全连接层、多尺度特征融合、跳跃连接和反卷积操作等技术细节。通过定义VOCSegDataset类处理数据集,构建FCN8s模型并完成训练与测试。实验结果展示了模型在图像分割任务中的应用效果,同时提供了内存使用优化的参考。
|
7月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本实验基于PyTorch Geometric,在昇腾平台上实现GraphSAGE图神经网络,使用CiteSeer数据集进行分类训练。内容涵盖GraphSAGE的创新点、算法原理、网络架构及实战分析。GraphSAGE通过采样和聚合节点邻居特征,支持归纳式学习,适用于未见节点的表征生成。实验包括模型搭建、训练与验证,并在NPU上运行,最终测试准确率达0.665。
|
3月前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
210 1
|
7月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
1074 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
3月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
171 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节

热门文章

最新文章

推荐镜像

更多