基于Fashion-MNIST数据集的模型剪枝(上)

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: 1. 介绍1.1 背景介绍目前在深度学习中存在一些困境,对于移动是设备来说,主要是算不好;穿戴设备算不来;数据中心,大多数人又算不起 。这就是做模型做压缩与加速的初衷。

1. 介绍

1.1 背景介绍

目前在深度学习中存在一些困境,对于移动是设备来说,主要是算不好;穿戴设备算不来;数据中心,大多数人又算不起 。这就是做模型做压缩与加速的初衷。当然这也是需要理论基础作为支撑的,对与许多网络结构中,如VGG-16网络,参数数量1亿3干多万,占用500MB空间,需要进行309亿次浮点运算才能完成一次图像识别任务,这说明了模型压缩是必要的。我们在谈谈它的可行性。论文Predicting parameters in deep learning提出,其实在很多深度的神经网络中存在着显著的冗余。仅仅使用很少一部分(5%)权值就足以预测剩余的权值。该论文还提出这些剩下的权值甚至可以直接不用被学习。也就是说,仅仅训练一小部分原来的权值参数就有可能达到和原来网络相近甚至超过原来网络的性能(可以看作一种正则化)。我们的最终目的是最大程度的减小模型复杂度,减少模型存储需要的空间,也致力于加速模型的训练和推测。


目前我们可以将模型压缩的方法分为前端压缩和后端压缩。前端压缩有知识蒸馏,紧凑模型设计和滤波器级别的剪枝等。后端压缩(极大的改造网络的结构)有低秩近似,不加限制的剪枝和参数的量化和二值化等。但是各种方法也有自己的特点。

方法名称 描述 应用场景 方法细节
低秩分解 使用矩阵对参数进行分解估计 卷积层和全连接层 标准化的途径,很容易实施,支持从零训练和预训练
剪枝 删除对准确率影响不大的参数 卷积层和全连接层 对不同的设置具有鲁棒性,可以达到较好效果,支持从零训练和预训练
量化 减少每一个权值所需要的比特数来压缩网络 卷积层和全连接层 多依赖于二进制编码方式,适合在FPGA、单片机等平台上部署
知识蒸馏 训练一个更紧凑的神经网络来从大的模型蒸馏知识 卷积层和全连接层 模型表现对应用程序和网络结构比较敏感,只能从零开始训练
转移、紧凑卷积核 设计特别的卷积核来保存参数 只有卷积层 算法依赖于应用程序,通常可以取得好的表现,只能从零开始训练

下面我基于Finsh_mnist数据集的模型剪枝教程,仅做练习参考。

1.2 剪枝方法

本文采用了一种标准的剪枝框架第一步是预训练,即按照正常方法训练初始模型。尽管与传统的训练有着相同的形式,但是不同于传统训练期望获取最终学习到的权值,预训练的目的是确定哪些连接比较重要。一种比较简单的想法是绝对值更大的权值往往更重要,这种做法应用广泛并能取得不错的效果。第二步是剪枝,一般的做法是确定一个阈值,绝对值大小低于國值的连接将会被去除。

1.3 数据集介绍

Fashion-MNIST是一个替代MNIST手写数字集的图像数据集。它是由Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自10种类别的共7万个不同商品的正面图片。Fashion-MNIST的大小、格式和训练集/测试集划分与原始的MNIST完全一致。60000/10000的训练测试数据划分,28x28的灰度图片。你可以直接用它来测试你的机器学习和深度学习算法性能,且不需要改动任何的代码。

1.4 导入相关的包

由于设备不给力,我使用了google.colab进行实验。

# 导入相关包
from google.colab import drive
drive.mount('/content/gdrive')
import os
os.chdir("/content/gdrive/My Drive/Colab Notebooks/pytorch深度学习")
import torch
from torch import nn
import torch.nn.functional as F
print(torch.__version__)
1.5.0+cu101
• 1

2. 构建网络

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.relu3 = nn.ReLU(inplace=True)
        self.linear1 = nn.Linear(7 * 7 * 64, 128)
        self.linear2 = nn.Linear(128, 10)
        self.loss = nn.CrossEntropyLoss()
    def forward(self, data):
        out = self.maxpool1(self.relu1(self.conv1(data)))
        out = self.maxpool2(self.relu2(self.conv2(out)))
        out = self.relu3(self.conv3(out))
        out = out.view(out.size(0), -1)
        out = self.linear1(out)
        out = self.linear2(out)
        return out
    def get_loss(self, output, label):
        return self.loss(output, label)
 if __name__ == '__main__':
    net = MyNet()
    for p in net.conv1.parameters():
        print(p.data.size())
    for p in net.linear1.parameters():
        print(p.data.size())
torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([128, 3136])
torch.Size([128])

3. 训练

3.1 数据载入

#图像的预处理归一化[0,1],transforms.Normalize即image=(image-mean)/std到[-1,1]
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
#加载训练集,不是独热编码
train_data = DataLoader(datasets.FashionMNIST("./datasets/", train=True, transform=trans, download=True),
                              batch_size=100, shuffle=True, drop_last=True)
#加载测试集
test_data = DataLoader(datasets.FashionMNIST("./datasets/", train=False, transform=trans, download=True),
                               batch_size=100, shuffle=True, drop_last=True)

3.1 训练方法

class Trainer:
    def __init__(self, save_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_path = save_path
        self.net = MyNet().to(self.device)
        self.optimizer = torch.optim.Adam(self.net.parameters())
        self.net.train()
    def evaluate_accuracy(self,test_data):
        acc_sum, n = 0.0, 0
        for X, y in test_data:
            X, y = X.to(self.device), y.to(self.device)
            if isinstance(self.net, torch.nn.Module):
                net.eval() # 评估模式, 这会关闭dropout
                acc_sum += (self.net(X).argmax(dim=1) == y).float().sum().item()
                net.train() # 改回训练模式
            else: # 自定义的模型
                if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                    # 将is_training设置成False
                    acc_sum += (self.net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
                else:
                    acc_sum += (self.net(X).argmax(dim=1) == y).float().sum().item() 
            n += y.shape[0]
        return acc_sum / n
    def train(self,train_data,test_data,epochs):
        for epoch in range(1, epochs):
            total = 0
            train_acc_sum,train_l_sum,n,start = 0.0,0.0,0,time.time()
            for i, (data, label) in enumerate(train_data):
                data, label = data.to(self.device), label.to(self.device)
                output = self.net(data)
                loss = self.net.get_loss(output, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total += len(data)
                #训练损失
                train_l_sum += loss.item()
                train_acc_sum += ((output.argmax(dim=1)) == label).sum().item()
                n += 100  
                progress = math.ceil(i / len(train_data) * 50)
                print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
                      (epoch, total, len(train_data.dataset),
                       '-' * progress + '>', progress * 2), end='')
            test_acc = self.evaluate_accuracy(test_data)
            print("\nepoch %d,loss %.4f, train_acc %.3f, test_acc %.3f,time %.1f sec"
              %(epoch,train_l_sum/n,train_acc_sum/n,test_acc,time.time()-start))
            torch.save(self.net.state_dict(), self.save_path)
if __name__ == '__main__':
    trainer = Trainer("./model/finsh_minst_net.pth")
    trainer.train(train_data,test_data,epochs=10)
Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%
epoch 1,loss 0.0045, train_acc 0.836, test_acc 0.877,time 11.8 sec
Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
epoch 2,loss 0.0028, train_acc 0.899, test_acc 0.897,time 11.8 sec
Train epoch 3: 60000/60000, [-------------------------------------------------->] 100%
epoch 3,loss 0.0024, train_acc 0.913, test_acc 0.899,time 12.0 sec
Train epoch 4: 60000/60000, [-------------------------------------------------->] 100%
epoch 4,loss 0.0021, train_acc 0.923, test_acc 0.909,time 11.8 sec
Train epoch 5: 60000/60000, [-------------------------------------------------->] 100%
epoch 5,loss 0.0019, train_acc 0.930, test_acc 0.917,time 12.1 sec
Train epoch 6: 60000/60000, [-------------------------------------------------->] 100%
epoch 6,loss 0.0017, train_acc 0.938, test_acc 0.912,time 11.9 sec
Train epoch 7: 60000/60000, [-------------------------------------------------->] 100%
epoch 7,loss 0.0015, train_acc 0.945, test_acc 0.921,time 11.8 sec
Train epoch 8: 60000/60000, [-------------------------------------------------->] 100%
epoch 8,loss 0.0013, train_acc 0.951, test_acc 0.922,time 11.7 sec
Train epoch 9: 60000/60000, [-------------------------------------------------->] 100%
epoch 9,loss 0.0012, train_acc 0.957, test_acc 0.918,time 11.9 sec


相关文章
|
7月前
|
机器学习/深度学习
CNN模型识别cifar数据集
构建简单的CNN模型识别cifar数据集。经过几天的简单学习,尝试写了一个简单的CNN模型通过cifar数据集进行训练。效果一般,测试集上的的表现并不好,说明模型的构建不怎么样。# -*- coding = utf-8 -*-# @Time : 2020/10/16 16:19# @Author : tcc# @File : cifar_test.py# @Software : pycha...
32 0
|
25天前
|
机器学习/深度学习 算法 数据挖掘
SAS使用鸢尾花(iris)数据集训练人工神经网络(ANN)模型
SAS使用鸢尾花(iris)数据集训练人工神经网络(ANN)模型
|
1月前
|
机器学习/深度学习 数据可视化 PyTorch
利用PyTorch实现基于MNIST数据集的手写数字识别
利用PyTorch实现基于MNIST数据集的手写数字识别
25 2
|
2月前
|
机器学习/深度学习 数据可视化 算法
基于MLP完成CIFAR-10数据集和UCI wine数据集的分类
基于MLP完成CIFAR-10数据集和UCI wine数据集的分类
38 0
|
9月前
|
机器学习/深度学习 数据可视化 自动驾驶
图像分类 | 基于 MNIST 数据集
图像分类 | 基于 MNIST 数据集
|
11月前
|
机器学习/深度学习 并行计算
探索用卷积神经网络实现MNIST数据集分类
探索用卷积神经网络实现MNIST数据集分类
103 0
|
12月前
|
机器学习/深度学习 Web App开发 人工智能
一个项目帮你了解数据集蒸馏Dataset Distillation
一个项目帮你了解数据集蒸馏Dataset Distillation
184 0
|
TensorFlow 算法框架/工具
实现mnist手写数字识别
实现mnist手写数字识别
|
机器学习/深度学习
LSTM应用于MNIST数据集分类
LSTM网络是序列模型,一般比较适合处理序列问题。这里把它用于手写数字图片的分类,其实就相当于把图片看作序列。
270 0
LSTM应用于MNIST数据集分类
|
机器学习/深度学习 数据中心
基于Fashion-MNIST数据集的模型剪枝(下)
1. 介绍 1.1 背景介绍 目前在深度学习中存在一些困境,对于移动是设备来说,主要是算不好;穿戴设备算不来;数据中心,大多数人又算不起 。这就是做模型做压缩与加速的初衷。
118 0
基于Fashion-MNIST数据集的模型剪枝(下)

热门文章

最新文章