基于Fashion-MNIST数据集的模型剪枝(下)

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: 1. 介绍 1.1 背景介绍:目前在深度学习中存在一些困境:对于移动设备来说,主要是算不好;穿戴设备算不来;数据中心,大多数人又算不起。这就是做模型压缩与加速的初衷。

4. 对模型进行剪枝

4.1 构建剪枝网络

import time

import torch
import torch.nn.utils.prune as prune
class Pruning:
    """Globally prune a trained ``MyNet`` by weight magnitude and save it.

    The ``weight`` tensors of conv1, conv2, conv3, linear1 and linear2 are
    pooled together, and the ``amount`` fraction of entries with the smallest
    absolute value is zeroed (``prune.global_unstructured`` with
    ``prune.L1Unstructured``).

    Args:
        net_path: path of the state_dict checkpoint to prune.
        amount: fraction of the pooled weights to remove, e.g. ``0.3``.
    """

    # Sub-modules of MyNet whose ``weight`` is pruned; also used as the
    # labels in the sparsity report.
    _LAYER_NAMES = ("conv1", "conv2", "conv3", "linear1", "linear2")

    def __init__(self, net_path, amount):
        self.net = MyNet()
        # map_location="cpu" lets a checkpoint saved from a GPU run load on a
        # CPU-only machine; pruning works on this CPU copy.
        self.net.load_state_dict(torch.load(net_path, map_location="cpu"))
        # Global pruning takes an iterable of (module, parameter_name) pairs.
        self.parameters_to_prune = tuple(
            (getattr(self.net, layer), 'weight') for layer in self._LAYER_NAMES
        )
        self.amount = amount

    def pruning(self):
        """Prune, make the pruning permanent, print sparsity and save.

        The pruned state_dict is written to
        ``./model/pruned_net_with_torch_{amount:.1f}_l1.pth``.
        """
        prune.global_unstructured(
            self.parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=self.amount,
        )
        # Fold weight_orig * weight_mask back into `weight` and drop the
        # weight_orig / weight_mask entries and the forward_pre_hook, so the
        # saved state_dict has the same keys as the unpruned model.
        for module, name in self.parameters_to_prune:
            prune.remove(module, name)

        total_zeros = 0
        total_elements = 0
        for layer, (module, name) in zip(self._LAYER_NAMES, self.parameters_to_prune):
            weight = getattr(module, name)
            zeros = int(torch.sum(weight == 0))
            elements = weight.nelement()
            total_zeros += zeros
            total_elements += elements
            print(
                "Sparsity in {}.{}: {:.2f}%".format(
                    layer, name, 100. * float(zeros) / float(elements)
                )
            )
        print(
            "Global sparsity: {:.2f}%".format(
                100. * float(total_zeros) / float(total_elements)
            )
        )
        torch.save(self.net.state_dict(), f"./model/pruned_net_with_torch_{self.amount:.1f}_l1.pth")
if __name__ == '__main__':
    # Pruning rates 10%, 20%, ..., 90%, each applied to the same trained checkpoint.
    rates = [0.1 * k for k in range(1, 10)]
    for rate in rates:
        pruning = Pruning("./model/finsh_minst_net.pth", rate)
        pruning.pruning()

5. 检测

class Detector:
    """Evaluate a saved ``MyNet`` checkpoint on a test DataLoader.

    Args:
        net_path: path of the state_dict checkpoint to evaluate.
    """

    def __init__(self, net_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.net = MyNet().to(self.device)
        # Without CUDA, keep deserialized storages on the CPU (so GPU-saved
        # checkpoints still load); with CUDA, load tensors to their saved devices.
        self.map_location = None if torch.cuda.is_available() else lambda storage, loc: storage
        self.net.load_state_dict(torch.load(net_path, map_location=self.map_location))
        # Evaluation mode (matters for dropout / batch-norm layers, if MyNet has any).
        self.net.eval()

    def detect(self, test_data):
        """Run the network over ``test_data`` and report loss and accuracy.

        Args:
            test_data: iterable of ``(data, label)`` batches that also exposes
                ``.dataset`` supporting ``len()`` (e.g. a DataLoader).

        Returns:
            ``[test_loss, accuracy]`` as Python floats: the summed
            ``self.net.get_loss`` values divided by the dataset size, and the
            fraction of correctly classified samples.
        """
        num_samples = len(test_data.dataset)
        test_loss = 0.0
        correct = 0
        start = time.perf_counter()
        with torch.no_grad():
            for data, label in test_data:
                data, label = data.to(self.device), label.to(self.device)
                output = self.net(data)
                # float() turns a 0-d (possibly CUDA) tensor into a Python number,
                # so the returned loss can be plotted or serialized directly.
                test_loss += float(self.net.get_loss(output, label))
                # Predicted class = index of the largest output.
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(label.view_as(pred)).sum().item()
        end = time.perf_counter()
        print(f"total time:{end - start}")
        # NOTE(review): if get_loss returns a per-batch *mean*, dividing the sum
        # of batch means by the sample count gives roughly loss / batch_size,
        # not the mean per-sample loss. Confirm get_loss's reduction.
        test_loss /= num_samples
        print('Test: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, num_samples,
            100. * correct / num_samples))
        # Returned loss and accuracy.
        return [test_loss, correct / num_samples]
if __name__ == '__main__':
    test_loss, accuracy, Parameter_compression_ratio = [], [], []
    # Unpruned baseline first (ratio 0), then one pruned checkpoint per rate.
    runs = [("./model/finsh_minst_net.pth", 0)]
    for i in range(1, 10):
        amount = 0.1 * i
        runs.append((f"./model/pruned_net_with_torch_{amount:.1f}_l1.pth", amount))
    for path, ratio in runs:
        print(path)
        detector1 = Detector(path)
        loss, acc = detector1.detect(test_data)
        test_loss.append(loss)
        accuracy.append(acc)
        Parameter_compression_ratio.append(ratio)
./model/finsh_minst_net.pth
total time:1.6475636959075928
Test: average loss: 0.0027, accuracy: 9179/10000 (92%)
./model/pruned_net_with_torch_0.1_l1.pth
total time:1.3875453472137451
Test: average loss: 0.0027, accuracy: 9179/10000 (92%)
./model/pruned_net_with_torch_0.2_l1.pth
total time:1.5390675067901611
Test: average loss: 0.0027, accuracy: 9179/10000 (92%)
./model/pruned_net_with_torch_0.3_l1.pth
total time:1.356163501739502
Test: average loss: 0.0026, accuracy: 9178/10000 (92%)
./model/pruned_net_with_torch_0.4_l1.pth
total time:1.4721436500549316
Test: average loss: 0.0026, accuracy: 9163/10000 (92%)
./model/pruned_net_with_torch_0.5_l1.pth
total time:1.429352045059204
Test: average loss: 0.0026, accuracy: 9134/10000 (91%)
./model/pruned_net_with_torch_0.6_l1.pth
total time:1.3589565753936768
Test: average loss: 0.0026, accuracy: 9119/10000 (91%)
./model/pruned_net_with_torch_0.7_l1.pth
total time:1.3456928730010986
Test: average loss: 0.0028, accuracy: 9026/10000 (90%)
./model/pruned_net_with_torch_0.8_l1.pth
total time:1.351386308670044
Test: average loss: 0.0046, accuracy: 8644/10000 (86%)
./model/pruned_net_with_torch_0.9_l1.pth
total time:1.4840266704559326
Test: average loss: 0.0135, accuracy: 7021/10000 (70%)
import numpy as np
import matplotlib.pyplot as plt

# Accuracy (left) and test loss (right) versus the fraction of weights pruned.
fig, ax = plt.subplots(1, 2, figsize=(9, 5))
# plt.subplots already created both axes; unpack them instead of re-fetching
# them with plt.subplot(121) / plt.subplot(122).
ax1, ax2 = ax
x = Parameter_compression_ratio
y = accuracy
# float() converts 0-d tensors (possibly on the GPU) into plain numbers that
# matplotlib can plot.
y2 = [float(v) for v in test_loss]
ax1.plot(x, y, color='red', label='accuracy')
ax2.plot(x, y2, color='blue', label='test_loss')
for axis, ylabel in ((ax1, 'accuracy'), (ax2, 'test loss')):
    axis.set_xlabel('parameter compression ratio')
    axis.set_ylabel(ylabel)
    axis.legend()
plt.show()

相关文章
|
6天前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
7月前
|
机器学习/深度学习
CNN模型识别cifar数据集
构建简单的CNN模型识别cifar数据集。经过几天的简单学习,尝试写了一个简单的CNN模型通过cifar数据集进行训练。效果一般,测试集上的表现并不好,说明模型的构建不怎么样。# -*- coding = utf-8 -*-# @Time : 2020/10/16 16:19# @Author : tcc# @File : cifar_test.py# @Software : pycha...
32 0
|
6天前
|
机器学习/深度学习 算法 数据挖掘
SAS使用鸢尾花(iris)数据集训练人工神经网络(ANN)模型
SAS使用鸢尾花(iris)数据集训练人工神经网络(ANN)模型
|
6天前
|
机器学习/深度学习 数据可视化 PyTorch
利用PyTorch实现基于MNIST数据集的手写数字识别
利用PyTorch实现基于MNIST数据集的手写数字识别
28 2
|
6天前
|
机器学习/深度学习 数据可视化 算法
基于MLP完成CIFAR-10数据集和UCI wine数据集的分类
基于MLP完成CIFAR-10数据集和UCI wine数据集的分类
42 0
|
9月前
|
机器学习/深度学习 数据可视化 自动驾驶
图像分类 | 基于 MNIST 数据集
图像分类 | 基于 MNIST 数据集
|
11月前
|
机器学习/深度学习 并行计算
探索用卷积神经网络实现MNIST数据集分类
探索用卷积神经网络实现MNIST数据集分类
105 0
|
机器学习/深度学习 Web App开发 人工智能
一个项目帮你了解数据集蒸馏Dataset Distillation
一个项目帮你了解数据集蒸馏Dataset Distillation
189 0
|
TensorFlow 算法框架/工具
实现mnist手写数字识别
实现mnist手写数字识别
|
机器学习/深度学习
LSTM应用于MNIST数据集分类
LSTM网络是序列模型,一般比较适合处理序列问题。这里把它用于手写数字图片的分类,其实就相当于把图片看作序列。
279 0
LSTM应用于MNIST数据集分类

相关实验场景

更多