4. 对模型进行剪枝
4.1 构建剪枝网络
import torch
import torch.nn.utils.prune as prune


class Pruning:
    """Globally prune a trained MyNet by L1 weight magnitude and save it.

    Args:
        net_path: path to the state_dict checkpoint of the model to prune.
        amount: global fraction of weights to zero out (0.0 - 1.0).
    """

    # Prunable layers, in the order their sparsity is reported.
    _LAYER_NAMES = ('conv1', 'conv2', 'conv3', 'linear1', 'linear2')

    def __init__(self, net_path, amount):
        self.net = MyNet()
        # Load the trained weights before pruning.
        self.net.load_state_dict(torch.load(net_path))
        # (module, parameter-name) pairs — the unit global pruning operates on.
        self.parameters_to_prune = tuple(
            (getattr(self.net, name), 'weight') for name in self._LAYER_NAMES
        )
        self.amount = amount

    def pruning(self):
        """Apply global unstructured L1 pruning, report sparsity, save weights."""
        # Zero the globally smallest |weight| entries across all listed layers.
        prune.global_unstructured(
            self.parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=self.amount,
        )
        # Make pruning permanent: removes weight_orig / weight_mask and the
        # forward_pre_hook, leaving a plain 'weight' tensor with zeros baked in.
        for module, param_name in self.parameters_to_prune:
            prune.remove(module, param_name)

        # Report per-layer and global sparsity (fraction of exact zeros).
        total_zeros = 0
        total_elems = 0
        for layer_name in self._LAYER_NAMES:
            weight = getattr(self.net, layer_name).weight
            zeros = int(torch.sum(weight == 0))
            elems = weight.nelement()
            total_zeros += zeros
            total_elems += elems
            print(
                "Sparsity in {}.weight: {:.2f}%".format(
                    layer_name, 100. * float(zeros) / float(elems)
                )
            )
        print(
            "Global sparsity: {:.2f}%".format(
                100. * float(total_zeros) / float(total_elems)
            )
        )
        torch.save(
            self.net.state_dict(),
            f"./model/pruned_net_with_torch_{self.amount:.1f}_l1.pth",
        )


if __name__ == '__main__':
    # Produce one pruned checkpoint for each amount 0.1 .. 0.9.
    for i in range(1, 10):
        pruning = Pruning("./model/finsh_minst_net.pth", 0.1 * i)
        pruning.pruning()
5. 检测
class Detector: def __init__(self, net_path): self.device = "cuda" if torch.cuda.is_available() else "cpu" self.net = MyNet().to(self.device) self.map_location = None if torch.cuda.is_available() else lambda storage, loc: storage self.net.load_state_dict(torch.load(net_path, map_location=self.map_location)) self.net.eval() def detect(self,test_data): test_loss = 0 correct = 0 start = time.time() with torch.no_grad(): for data, label in test_data: data, label = data.to(self.device), label.to(self.device) output = self.net(data) test_loss += self.net.get_loss(output, label) pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(label.view_as(pred)).sum().item() end = time.time() print(f"total time:{end - start}") test_loss /= len(test_data.dataset) print('Test: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_data.dataset), 100. * correct / len(test_data.dataset))) #返回的损失和正确率 return [test_loss,correct/len(test_data.dataset)] if __name__ == '__main__': print("./model/finsh_minst_net.pth") test_loss,accuracy,Parameter_compression_ratio = [],[],[] detector1 = Detector("./model/finsh_minst_net.pth") loss,acc = detector1.detect(test_data) test_loss.append(loss) accuracy.append(acc) Parameter_compression_ratio.append(0) for i in range(1, 10): amount = 0.1 * i print(f"./model/pruned_net_with_torch_{amount:.1f}_l1.pth") detector1 = Detector(f"./model/pruned_net_with_torch_{amount:.1f}_l1.pth") loss,acc = detector1.detect(test_data) test_loss.append(loss) accuracy.append(acc) Parameter_compression_ratio.append(amount)
./model/finsh_minst_net.pth total time:1.6475636959075928 Test: average loss: 0.0027, accuracy: 9179/10000 (92%) ./model/pruned_net_with_torch_0.1_l1.pth total time:1.3875453472137451 Test: average loss: 0.0027, accuracy: 9179/10000 (92%) ./model/pruned_net_with_torch_0.2_l1.pth total time:1.5390675067901611 Test: average loss: 0.0027, accuracy: 9179/10000 (92%) ./model/pruned_net_with_torch_0.3_l1.pth total time:1.356163501739502 Test: average loss: 0.0026, accuracy: 9178/10000 (92%) ./model/pruned_net_with_torch_0.4_l1.pth total time:1.4721436500549316 Test: average loss: 0.0026, accuracy: 9163/10000 (92%) ./model/pruned_net_with_torch_0.5_l1.pth total time:1.429352045059204 Test: average loss: 0.0026, accuracy: 9134/10000 (91%) ./model/pruned_net_with_torch_0.6_l1.pth total time:1.3589565753936768 Test: average loss: 0.0026, accuracy: 9119/10000 (91%) ./model/pruned_net_with_torch_0.7_l1.pth total time:1.3456928730010986 Test: average loss: 0.0028, accuracy: 9026/10000 (90%) ./model/pruned_net_with_torch_0.8_l1.pth total time:1.351386308670044 Test: average loss: 0.0046, accuracy: 8644/10000 (86%) ./model/pruned_net_with_torch_0.9_l1.pth total time:1.4840266704559326 Test: average loss: 0.0135, accuracy: 7021/10000 (70%)
import numpy as np import matplotlib.pyplot as plt fig,ax = plt.subplots(1,2,figsize=(9,5)) ax1 = plt.subplot(121) #绘制子图1对象 ax2 = plt.subplot(122) #绘制子图2对象 x = Parameter_compression_ratio y = accuracy y2 = test_loss ax1.plot(x,y,color='red',label='accuracy') ax2.plot(x,y2,color='blue',label='test_loss') ax1.legend() ax2.legend() plt.show()