pytorch利用hook【钩子】获取torch网络每层结构【附代码】

简介: 笔记

写本文的目的是为了方便在剪枝中或其他应用中获取网络结构,如何有效的利用hook获取每层的结构来判断是否可以剪枝。


要对网络进行trace,或者获取网络结构,需要知道“grad_fn”。我们知道在pytorch中导数对应的关键词为“grad”。对一个变量我们可以设置requires_grad为True或者False来设置该变量是否求偏导。


grad_fn


grad_fn: grad_fn用来记录变量变化的过程,方便计算梯度,比如:y = x*2,grad_fn记录了y由x计算的过程。

这里举个例子:设置一个x,并设置其可求导,就也是后面要对他求偏导。

x = torch.ones(2,2, requires_grad=True)
x
Out[4]: 
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)

当我们输出x的grad_fn时输出为None,这是因为这里的tensor x是直接给出的,他并没有经过任何的运算。


print(x.grad_fn)
None

当我们设置一个简单的二次函数y=2 * x,可以得到如下结果,可以看到grad_fn现在显示的MulBackward0,意思就是用了乘法。


y = 2*x
Out[10]: 
tensor([[2., 2.],
        [2., 2.]], grad_fn=<MulBackward0>)

hook获取网络结构


通过理解了grad_fn,那么我们就可以对网络进行trace,获取每层的网络结构了。


这里以YOLOv5s为例。先附上代码,PRUNABLE_MODULES列表放了Conv,BN以及PReLu。


grad_fn_to_module字典是用来通过grad_fn获取网络每层的结构,也就是如果grad_fn不为None的时候就放入字典中。


visited用来记录每层出现的次数。


这里会用到一个关键的函数:register_forward_hook。


该函数的作用是在不改变torch网络的情况下获取每层的输出。该方法需要传入一个func,其中包含module,inputs,outputs。也就是我下面代码中定义的_record_module_grad_fn。


import torch
import torch.nn as nn
PRUNABLE_MODULES = [ nn.modules.conv._ConvNd, nn.modules.batchnorm._BatchNorm, nn.Linear, nn.PReLU]
grad_fn_to_module = {}  # 如果获取不到是无法剪枝的
visited = {}  # visited会记录每层出现的次数
def _record_module_grad_fn(module, inputs, outputs): # 记录model的grad_fn
    if module not in visited:
        visited[module] = 1
    else:
        visited[module] += 1
    grad_fn_to_module[outputs.grad_fn] = module
model = torch.load('../runs/train/exp/weights/best.pt')['model'].float().cpu()
for para in model.parameters():
    para.requires_grad = True
x = torch.ones(1, 3, 640, 640)
for m in model.modules():
    if isinstance(m, tuple(PRUNABLE_MODULES)):
        hooks = [m.register_forward_hook(_record_module_grad_fn)]
out = model(x)
for hook in hooks:
    hook.remove()
print(grad_fn_to_module)


这里需要注意:在代码运行到out = model(x)之前的过程中,grad_fn_to_module字典一直为空。通过debug也可以看到。

45.png


但是!!当我用样例x将我的mode跑了一遍获得out的时候,此刻grad_fn_to_module就开始将网络从头到尾开始记录了。该字典内容如下,可以看到针对第一个key为Convolution操作,所以记录下了Conv2d(3,32,.....)这一层,后面都是如此。


{: Conv2d(3, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2), bias=False), : BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), : BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False),



46.png

通过上面的方法我们就可以通过hook获取网络的每层结构。这个就可以用来做剪枝操作。


注意!在我的代码中可以看到有这么一行:


for para in model.parameters():
    para.requires_grad = True

我这里将模型的所有参数均设置为可导的,为什么要这里设置呢,这是因为我在对官方代码yolov5 6.0代码剪枝的时候,发现backbone无法剪枝,比如我想对第一层进行剪枝,会给我报KeyError的错误,最后通过仔细研究发现,在官方提供的v5模型中backbone的grad_fn均为None,利用hook无法获得网络,只能获得head部分的结构,下面显示是backbone的grad_fn为None记录的结构,:解决的办法也很简单,就是加入我上面的代码,并设置参数可导即可。


{None: BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False), :  


你学会了吗~

目录
相关文章
|
3月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 中的动态计算图:实现灵活的神经网络架构
【8月更文第27天】PyTorch 是一款流行的深度学习框架,它以其灵活性和易用性而闻名。与 TensorFlow 等其他框架相比,PyTorch 最大的特点之一是支持动态计算图。这意味着开发者可以在运行时定义网络结构,这为构建复杂的模型提供了极大的便利。本文将深入探讨 PyTorch 中动态计算图的工作原理,并通过一些示例代码展示如何利用这一特性来构建灵活的神经网络架构。
284 1
|
3月前
|
机器学习/深度学习 资源调度 自然语言处理
不同类型的循环神经网络结构
【8月更文挑战第16天】
50 0
|
29天前
|
机器学习/深度学习 计算机视觉 网络架构
【YOLO11改进 - C3k2融合】C3k2融合YOLO-MS的MSBlock : 分层特征融合策略,轻量化网络结构
【YOLO11改进 - C3k2融合】C3k2融合YOLO-MS的MSBlock : 分层特征融合策略,轻量化网络结构
|
1月前
|
机器学习/深度学习 存储 数据可视化
以pytorch的forward hook为例探究hook机制
【10月更文挑战第10天】PyTorch 的 Hook 机制允许用户在不修改模型代码的情况下介入前向和反向传播过程,适用于模型可视化、特征提取及梯度分析等任务。通过注册 `forward hook`,可以在模型前向传播过程中插入自定义操作,如记录中间层输出。使用时需注意输入输出格式及计算资源占用。
|
1月前
|
机器学习/深度学习 监控 PyTorch
以pytorch的forward hook为例探究hook机制
【10月更文挑战第9天】PyTorch中的Hook机制允许在不修改模型代码的情况下,获取和修改模型中间层的信息,如输入和输出等,适用于模型可视化、特征提取及梯度计算。Forward Hook在前向传播后触发,可用于特征提取和模型监控。实现上,需定义接收模块、输入和输出参数的Hook函数,并将其注册到目标层。与Backward Hook相比,前者关注前向传播,后者侧重反向传播和梯度处理,两者共同增强了对模型内部运行情况的理解和控制。
|
3月前
|
机器学习/深度学习 人工智能 PyTorch
【深度学习】使用PyTorch构建神经网络:深度学习实战指南
PyTorch是一个开源的Python机器学习库,特别专注于深度学习领域。它由Facebook的AI研究团队开发并维护,因其灵活的架构、动态计算图以及在科研和工业界的广泛支持而受到青睐。PyTorch提供了强大的GPU加速能力,使得在处理大规模数据集和复杂模型时效率极高。
193 59
|
1月前
|
机器学习/深度学习 算法
神经网络的结构与功能
神经网络是一种广泛应用于机器学习和深度学习的模型,旨在模拟人类大脑的信息处理方式。它们由多层不同类型的节点或“神经元”组成,每层都有特定的功能和责任。
39 0
|
1月前
|
PyTorch 算法框架/工具 Python
Pytorch学习笔记(十):Torch对张量的计算、Numpy对数组的计算、它们之间的转换
这篇文章是关于PyTorch张量和Numpy数组的计算方法及其相互转换的详细学习笔记。
35 0
|
2月前
|
机器学习/深度学习
小土堆-pytorch-神经网络-损失函数与反向传播_笔记
在使用损失函数时,关键在于匹配输入和输出形状。例如,在L1Loss中,输入形状中的N代表批量大小。以下是具体示例:对于相同形状的输入和目标张量,L1Loss默认计算差值并求平均;此外,均方误差(MSE)也是常用损失函数。实战中,损失函数用于计算模型输出与真实标签间的差距,并通过反向传播更新模型参数。
|
2月前
|
编解码 人工智能 文件存储
卷积神经网络架构:EfficientNet结构的特点
EfficientNet是一种高效的卷积神经网络架构,它通过系统化的方法来提升模型的性能和效率。
60 1
下一篇
无影云桌面