写本文的目的是为了方便在剪枝中或其他应用中获取网络结构,如何有效的利用hook获取每层的结构来判断是否可以剪枝。
要对网络进行trace,或者获取网络结构,需要知道“grad_fn”。我们知道在pytorch中导数对应的关键词为“grad”。对一个变量我们可以设置requires_grad为True或者False来设置该变量是否求偏导。
grad_fn
grad_fn: grad_fn用来记录变量变化的过程,方便计算梯度,比如:y = x*2,grad_fn记录了y由x计算的过程。
这里举个例子:设置一个x,并设置其可求导,就也是后面要对他求偏导。
x = torch.ones(2,2, requires_grad=True) x Out[4]: tensor([[1., 1.], [1., 1.]], requires_grad=True)
当我们输出x的grad_fn时输出为None,这是因为这里的tensor x是直接给出的,他并没有经过任何的运算。
print(x.grad_fn) None
当我们设置一个简单的二次函数y=2 * x,可以得到如下结果,可以看到grad_fn现在显示的MulBackward0,意思就是用了乘法。
y = 2*x Out[10]: tensor([[2., 2.], [2., 2.]], grad_fn=<MulBackward0>)
hook获取网络结构
通过理解了grad_fn,那么我们就可以对网络进行trace,获取每层的网络结构了。
这里以YOLOv5s为例。先附上代码,PRUNABLE_MODULES列表放了Conv,BN以及PReLu。
grad_fn_to_module字典是用来通过grad_fn获取网络每层的结构,也就是如果grad_fn不为None的时候就放入字典中。
visited用来记录每层出现的次数。
这里会用到一个关键的函数:register_forward_hook。
该函数的作用是在不改变torch网络的情况下获取每层的输出。该方法需要传入一个func,其中包含module,inputs,outputs。也就是我下面代码中定义的_record_module_grad_fn。
import torch import torch.nn as nn PRUNABLE_MODULES = [ nn.modules.conv._ConvNd, nn.modules.batchnorm._BatchNorm, nn.Linear, nn.PReLU] grad_fn_to_module = {} # 如果获取不到是无法剪枝的 visited = {} # visited会记录每层出现的次数 def _record_module_grad_fn(module, inputs, outputs): # 记录model的grad_fn if module not in visited: visited[module] = 1 else: visited[module] += 1 grad_fn_to_module[outputs.grad_fn] = module model = torch.load('../runs/train/exp/weights/best.pt')['model'].float().cpu() for para in model.parameters(): para.requires_grad = True x = torch.ones(1, 3, 640, 640) for m in model.modules(): if isinstance(m, tuple(PRUNABLE_MODULES)): hooks = [m.register_forward_hook(_record_module_grad_fn)] out = model(x) for hook in hooks: hook.remove() print(grad_fn_to_module)
这里需要注意:在代码运行到out = model(x)之前的过程中,grad_fn_to_module字典一直为空。通过debug也可以看到。
但是!!当我用样例x将我的mode跑了一遍获得out的时候,此刻grad_fn_to_module就开始将网络从头到尾开始记录了。该字典内容如下,可以看到针对第一个key为Convolution操作,所以记录下了Conv2d(3,32,.....)这一层,后面都是如此。
{: Conv2d(3, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2), bias=False), : BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), : BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False),
通过上面的方法我们就可以通过hook获取网络的每层结构。这个就可以用来做剪枝操作。
注意!在我的代码中可以看到有这么一行:
for para in model.parameters(): para.requires_grad = True
我这里将模型的所有参数均设置为可导的,为什么要这里设置呢,这是因为我在对官方代码yolov5 6.0代码剪枝的时候,发现backbone无法剪枝,比如我想对第一层进行剪枝,会给我报KeyError的错误,最后通过仔细研究发现,在官方提供的v5模型中backbone的grad_fn均为None,利用hook无法获得网络,只能获得head部分的结构,下面显示是backbone的grad_fn为None记录的结构,:解决的办法也很简单,就是加入我上面的代码,并设置参数可导即可。
{None: BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False), : BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), : Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False), :
你学会了吗~