1. register_hook
由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数。
钩子函数包括Variable的钩子和nn.Module钩子,用法相似。
import torch from torch.autograd import Variable grad_list = [] def print_grad(grad): grad_list.append(grad) x = Variable(torch.tensor([1.,2.]), requires_grad=True) y = x+2 print("x:", x) print("y:", y) z = torch.mean(torch.pow(y, 2)) print("z:", z) lr = 1e-3 y.register_hook(print_grad) z.backward() x.data -= lr*x.grad.data print("y.grad:", grad_list) print("x.grad:", x.grad) print('new x is:',x)
输出:
x: tensor([1., 2.], requires_grad=True) y: tensor([3., 4.], grad_fn=<AddBackward0>) z: tensor(12.5000, grad_fn=<MeanBackward0>) y.grad: [tensor([3., 4.])] x.grad: tensor([3., 4.]) new x is: tensor([0.9970, 1.9960], requires_grad=True)
所以z对x求偏导的结果也是[3,4],可以看见输出的结果x.grad: tensor([3., 4.])也是正确的。
同时这里需要注意,如果不使用register_hook函数,是没有办法获取y的中间梯度信息的。也就是z对于x的grad是存在的,但是z对于中间变量y的grad是不存在的,也就验证了Pytorch会自动舍弃图计算的中间结果这句话。
2. register_forward_hook
这两个函数的功能类似于variable函数的register_hook,可在module前向传播或反向传播时注册钩子。
每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None
钩子函数不应修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用钩子技术就更合适一些。下面考虑一种场景,有一个预训练好的模型,需要提取模型的某一层(不是最后一层)的输出作为特征进行分类,但又不希望修改其原有的模型定义文件,这时就可以利用钩子函数。
下面以pytorch实现的resnet来测试,其官方的代码中farward函数如下所示,这里想提出去输入layer4的特征矩阵(也就是layer3的输出特征矩阵),以及layer4输出的特征矩阵:
# resnet50的forward函数如下所示: def _forward_impl(self, x: Tensor) -> Tensor: # See note [TorchScript super()] x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) # layer3的Bottleneck如下所示,特征矩阵的channels输出为1024 Bottleneck( (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) ) # layer4Bottleneck如下所示,特征矩阵的channels输出为2048 Bottleneck( (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) ) # 以上结构可以通过print(model)来打印出来查看
这里我想提取最后一层特征层的输出,也就是layer4层输入输出的特征矩阵,参考代码如下:
import torch from torchvision.models import resnet50 features_input = [] features_output = [] def hook(modele, input, output): features_input.append(input) features_output.append(output) model = resnet50(pretrained=True) handle = model.layer4.register_forward_hook(hook) x = torch.rand([1, 3, 224, 224]) out = model(x) handle.remove() print("out.shape:", out.shape) print("features_input.shape:", features_input[0][0].shape) print("features_output.shape:", features_output[0].shape)
输出:
out.shape: torch.Size([1, 1000]) features_input.shape: torch.Size([1, 1024, 14, 14]) features_output.shape: torch.Size([1, 2048, 7, 7])
可以看见,特征维度的结果和我们计算出的结果是一样的
3. register_backward_hook
以上两个函数的功能类似于variable函数的register_hook,可在module前向传播或反向传播时注册钩子。
每次前向传播执行结束后会执行钩子函数(hook)。前向传播的钩子函数具有如下形式:hook(module, input, output) -> None,而反向传播则具有如下形式:hook(module, grad_input, grad_output) -> Tensor or None。
对于需要获得网络方向传播的梯度信息可以使用这个函数,通过反向传播的梯度信息在cam这种可解释性的深度学习中会使用到。关于grad-cam的介绍与使用可以看一下我前面的一篇文章:grad-cam的简单逻辑实现以及效果展示
同样的,这里想要获得resnet50中,layer4层的输入输出的梯度信息:
import torch from torchvision.models import resnet50 grads_input = [] grads_output = [] def hook(module, grad_input, grad_output): grads_input.append(grad_input) grads_output.append(grad_output) model = resnet50(pretrained=True) handle = model.layer4.register_backward_hook(hook) x = torch.rand([1, 3, 224, 224]) out = model(x) out[0][0].backward() handle.remove() print("out.shape:", out.shape) print("grads_input.shape:", grads_input[0][0].shape) print("grads_output.shape:", grads_output[0][0].shape)
输出:
out.shape: torch.Size([1, 1000]) grads_input.shape: torch.Size([1, 2048, 7, 7]) grads_output.shape: torch.Size([1, 2048, 7, 7])
可以知道,由于对于layer4来说,其输出特征矩阵维度是 torch.Size([1, 2048, 7, 7]),所以对于其方向传播的输入就应该是 torch.Size([1, 2048, 7, 7])。
下面尝试,能否多注册两个方向传播的钩子函数,来同时获得layer3与layer4的方向梯度信息:
import torch from torchvision.models import resnet50 grads_ly3_input = [] grads_ly3_output = [] def hook_ly3(module, grad_input, grad_output): grads_ly3_input.append(grad_input) grads_ly3_output.append(grad_output) grads_ly4_input = [] grads_ly4_output = [] def hook_ly4(module, grad_input, grad_output): grads_ly4_input.append(grad_input) grads_ly4_output.append(grad_output) model = resnet50(pretrained=True) handle_ly3 = model.layer3.register_backward_hook(hook_ly3) handle_ly4 = model.layer4.register_backward_hook(hook_ly4) x = torch.rand([1, 3, 224, 224]) out = model(x) out[0][0].backward() handle_ly3.remove() handle_ly4.remove() print("out.shape:", out.shape) print("grads_ly3_input.shape:", grads_ly3_input[0][0].shape) print("grads_ly3_output.shape:", grads_ly3_output[0][0].shape) print("grads_ly4_input.shape:", grads_ly4_input[0][0].shape) print("grads_ly4_output.shape:", grads_ly4_output[0][0].shape)
输出:
out.shape: torch.Size([1, 1000]) grads_ly3_input.shape: torch.Size([1, 1024, 14, 14]) grads_ly3_output.shape: torch.Size([1, 1024, 14, 14]) grads_ly4_input.shape: torch.Size([1, 2048, 7, 7]) grads_ly4_output.shape: torch.Size([1, 2048, 7, 7])
由于对于layer3来说,其输出特征矩阵维度是 torch.Size([1, 1024, 14, 14]),所以对于其方向传播的输入就应该是 torch.Size([1, 1024, 14, 14]);由于对于layer4来说,其输出特征矩阵维度是 torch.Size([1, 2048, 7, 7]),所以对于其方向传播的输入就应该是 torch.Size([1, 2048, 7, 7])。
参考资料:
Pytorch_hook机制的理解及利用register_forward_hook(hook)中间层输出
『PyTorch』第十六弹_hook技术