使用CNN完成MNIST手写体识别(PyTorch)
卷积神经网络(Convolutional Neural Network,简称CNN)是一种专门用于处理图像、语音、自然语言等数据的深度学习模型。CNN的特点是可以通过卷积运算提取出图像、语音等数据中的特征,从而实现对这些数据进行分类、识别等任务。
CNN的基本结构包括卷积层、池化层和全连接层。其中卷积层是CNN的核心部分,它可以通过卷积核(或滤波器)对输入数据进行卷积运算,从而提取出数据中的空间特征,如边缘、角等。卷积层的输出经过池化层的降采样处理,可以减少参数数量,提高模型的泛化能力。全连接层则将池化层输出的特征向量连接起来,通过权重矩阵进行分类、识别等任务。
CNN的训练过程通常采用反向传播算法来更新网络中的权重参数。反向传播算法可以根据损失函数的导数来逐层计算各层的误差,从而调整各层的权重参数,使得模型对训练数据的拟合效果更好。
CNN在图像识别、目标检测、人脸识别等领域都有广泛应用。其中经典的卷积神经网络模型包括LeNet、AlexNet、VGG、GoogLeNet和ResNet等。这些模型在不同的任务中都取得了很好的效果,为深度学习领域的发展做出了重要贡献。
总的来说,卷积神经网络是一种能够有效处理图像、语音等数据的深度学习模型,在计算机视觉、语音识别等领域具有广泛的应用前景。
1. 导入PyTorch库
import torch import numpy as np from torch import nn from torch.utils.data import DataLoader from torch.autograd import Variable from torchvision.datasets import mnist from torchvision import transforms from torch import optim
2. 定义CNN类
# 定义CNN class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3), nn.BatchNorm2d(16), nn.ReLU(inplace=True) ) self.layer2 = nn.Sequential( nn.Conv2d(16, 32, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2) ) self.layer3 = nn.Sequential( nn.Conv2d(32, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True) ) self.layer4 = nn.Sequential( nn.Conv2d(64, 128, kernel_size=3), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2,stride=2) ) self.fc = nn.Sequential( nn.Linear(128 * 4 * 4, 1024), nn.ReLU(inplace=True), nn.Linear(1024, 128), nn.ReLU(inplace=True), nn.Linear(128, 10) ) def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = x.view(x.size(0), -1) x = self.fc(x) return x
# 数据集转换 data_tf = transforms.Compose( [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] )
3. 下载数据集
# 使用内置函数下载mnist数据集 train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True) train_set, test_set
(Dataset MNIST Number of datapoints: 60000 Root location: ./data Split: Train StandardTransform Transform: Compose( ToTensor() Normalize(mean=[0.5], std=[0.5]) ), Dataset MNIST Number of datapoints: 10000 Root location: ./data Split: Test StandardTransform Transform: Compose( ToTensor() Normalize(mean=[0.5], std=[0.5]) ))
# 划分训练集与测试集 train_data = DataLoader(train_set, batch_size=100, shuffle=True) test_data = DataLoader(test_set, batch_size=100, shuffle=False) train_data, test_data
(<torch.utils.data.dataloader.DataLoader at 0x7f43c81d6eb8>, <torch.utils.data.dataloader.DataLoader at 0x7f43c81d6e10>)
4. 训练模型
# 调用卷积神经网络 net = CNN() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), 1e-1)
# 开始训练 losses = [] acces = [] eval_losses = [] eval_acces = [] nums_epoch = 1 print("开始训练......") for epoch in range(nums_epoch): print("Test:" + str(epoch)) train_loss = 0 train_acc = 0 net = net.train() i = 0 for img, label in train_data: i +=1 print("第" + str(i) + "批训练") img = Variable(img) label =Variable(label) # 前向传播 out = net(img) loss = criterion(out, label) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() # 记录误差 train_loss += loss.item() # 计算分类的准确率 _, pred = out.max(1) num_correct = (pred ==label).sum().item() acc = num_correct / img.shape[0] # 记录准确率 train_acc += acc losses.append(train_loss / len(train_data)) acces.append(train_acc / len(train_data)) eval_loss = 0 eval_acc = 0 # 测试集不训练 for img, label in test_data: img = Variable(img) label = Variable(label) # 前向传播 out = net(img) loss = criterion(out, label) # 记录误差 eval_loss += loss.item() # 计算分类的准确率 _, pred = out.max(1) num_correct = (pred == label).sum().item() acc = num_correct / img.shape[0] # 记录准确率 eval_acc += acc eval_losses.append(eval_loss / len(test_data)) eval_acces.append(eval_acc / len(test_data)) print('Epoch {}: \nTrain Loss: {} Train Accuracy: {} \nTest Loss: {} Test Accuarcy: {}'.format( epoch + 1, train_loss / len(train_data), train_acc / len(train_data), eval_loss / len(test_data), eval_acc / len(test_data)))
开始训练...... Test:0 第1批训练 第2批训练 第3批训练 第4批训练 第5批训练 第6批训练 第7批训练 第8批训练 第9批训练 第10批训练 第11批训练 第12批训练 第13批训练 第14批训练 第15批训练 第16批训练 第17批训练 第18批训练 第19批训练 第20批训练 第21批训练 第22批训练 第23批训练 第24批训练 第25批训练 第26批训练 第27批训练 第28批训练 第29批训练 第30批训练 第31批训练 第32批训练 第33批训练 第34批训练 第35批训练 第36批训练 第37批训练 第38批训练 第39批训练 第40批训练 第41批训练 第42批训练 第43批训练 第44批训练 第45批训练 第46批训练 第47批训练 第48批训练 第49批训练 第50批训练 第51批训练 第52批训练 第53批训练 第54批训练 第55批训练 第56批训练 第57批训练 第58批训练 第59批训练 第60批训练 第61批训练 第62批训练 第63批训练 第64批训练 第65批训练 第66批训练 第67批训练 第68批训练 第69批训练 第70批训练 第71批训练 第72批训练 第73批训练 第74批训练 第75批训练 第76批训练 第77批训练 第78批训练 第79批训练 第80批训练 第81批训练 第82批训练 第83批训练 第84批训练 第85批训练 第86批训练 第87批训练 第88批训练 第89批训练 第90批训练 第91批训练 第92批训练 第93批训练 第94批训练 第95批训练 第96批训练 第97批训练 第98批训练 第99批训练 第100批训练 第101批训练 第102批训练 第103批训练 第104批训练 第105批训练 第106批训练 第107批训练 第108批训练 第109批训练 第110批训练 第111批训练 第112批训练 第113批训练 第114批训练 第115批训练 第116批训练 第117批训练 第118批训练 第119批训练 第120批训练 第121批训练 第122批训练 第123批训练 第124批训练 第125批训练 第126批训练 第127批训练 第128批训练 第129批训练 第130批训练 第131批训练 第132批训练 第133批训练 第134批训练 第135批训练 第136批训练 第137批训练 第138批训练 第139批训练 第140批训练 第141批训练 第142批训练 第143批训练 第144批训练 第145批训练 第146批训练 第147批训练 第148批训练 第149批训练 第150批训练 第151批训练 第152批训练 第153批训练 第154批训练 第155批训练 第156批训练 第157批训练 第158批训练 第159批训练 第160批训练 第161批训练 第162批训练 第163批训练 第164批训练 第165批训练 第166批训练 第167批训练 第168批训练 第169批训练 第170批训练 第171批训练 第172批训练 第173批训练 第174批训练 第175批训练 第176批训练 第177批训练 第178批训练 第179批训练 第180批训练 第181批训练 第182批训练 第183批训练 第184批训练 第185批训练 第186批训练 第187批训练 第188批训练 第189批训练 第190批训练 第191批训练 第192批训练 第193批训练 第194批训练 第195批训练 第196批训练 第197批训练 第198批训练 第199批训练 第200批训练 第201批训练 第202批训练 第203批训练 第204批训练 第205批训练 第206批训练 第207批训练 第208批训练 第209批训练 第210批训练 第211批训练 第212批训练 第213批训练 第214批训练 第215批训练 第216批训练 第217批训练 第218批训练 第219批训练 第220批训练 第221批训练 第222批训练 第223批训练 第224批训练 第225批训练 第226批训练 第227批训练 第228批训练 第229批训练 第230批训练 第231批训练 第232批训练 第233批训练 第234批训练 第235批训练 第236批训练 第237批训练 第238批训练 第239批训练 第240批训练 第241批训练 第242批训练 第243批训练 第244批训练 第245批训练 第246批训练 第247批训练 第248批训练 第249批训练 第250批训练 第251批训练 第252批训练 第253批训练 第254批训练 第255批训练 第256批训练 第257批训练 第258批训练 第259批训练 第260批训练 第261批训练 第262批训练 第263批训练 第264批训练 第265批训练 第266批训练 第267批训练 第268批训练 第269批训练 第270批训练 第271批训练 第272批训练 第273批训练 第274批训练 第275批训练 第276批训练 第277批训练 第278批训练 第279批训练 第280批训练 第281批训练 第282批训练 第283批训练 第284批训练 第285批训练 第286批训练 第287批训练 第288批训练 第289批训练 第290批训练 第291批训练 第292批训练 第293批训练 第294批训练 第295批训练 第296批训练 第297批训练 第298批训练 第299批训练 第300批训练 第301批训练 第302批训练 第303批训练 第304批训练 第305批训练 第306批训练 第307批训练 第308批训练 第309批训练 第310批训练 第311批训练 第312批训练 第313批训练 第314批训练 第315批训练 第316批训练 第317批训练 第318批训练 第319批训练 第320批训练 第321批训练 第322批训练 第323批训练 第324批训练 第325批训练 第326批训练 第327批训练 第328批训练 第329批训练 第330批训练 第331批训练 第332批训练 第333批训练 第334批训练 第335批训练 第336批训练 第337批训练 第338批训练 第339批训练 第340批训练 第341批训练 第342批训练 第343批训练 第344批训练 第345批训练 第346批训练 第347批训练 第348批训练 第349批训练 第350批训练 第351批训练 第352批训练 第353批训练 第354批训练 第355批训练 第356批训练 第357批训练 第358批训练 第359批训练 第360批训练 第361批训练 第362批训练 第363批训练 第364批训练 第365批训练 第366批训练 第367批训练 第368批训练 第369批训练 第370批训练 第371批训练 第372批训练 第373批训练 第374批训练 第375批训练 第376批训练 第377批训练 第378批训练 第379批训练 第380批训练 第381批训练 第382批训练 第383批训练 第384批训练 第385批训练 第386批训练 第387批训练 第388批训练 第389批训练 第390批训练 第391批训练 第392批训练 第393批训练 第394批训练 第395批训练 第396批训练 第397批训练 第398批训练 第399批训练 第400批训练 第401批训练 第402批训练 第403批训练 第404批训练 第405批训练 第406批训练 第407批训练 第408批训练 第409批训练 第410批训练 第411批训练 第412批训练 第413批训练 第414批训练 第415批训练 第416批训练 第417批训练 第418批训练 第419批训练 第420批训练 第421批训练 第422批训练 第423批训练 第424批训练 第425批训练 第426批训练 第427批训练 第428批训练 第429批训练 第430批训练 第431批训练 第432批训练 第433批训练 第434批训练 第435批训练 第436批训练 第437批训练 第438批训练 第439批训练 第440批训练 第441批训练 第442批训练 第443批训练 第444批训练 第445批训练 第446批训练 第447批训练 第448批训练 第449批训练 第450批训练 第451批训练 第452批训练 第453批训练 第454批训练 第455批训练 第456批训练 第457批训练 第458批训练 第459批训练 第460批训练 第461批训练 第462批训练 第463批训练 第464批训练 第465批训练 第466批训练 第467批训练 第468批训练 第469批训练 第470批训练 第471批训练 第472批训练 第473批训练 第474批训练 第475批训练 第476批训练 第477批训练 第478批训练 第479批训练 第480批训练 第481批训练 第482批训练 第483批训练 第484批训练 第485批训练 第486批训练 第487批训练 第488批训练 第489批训练 第490批训练 第491批训练 第492批训练 第493批训练 第494批训练 第495批训练 第496批训练 第497批训练 第498批训练 第499批训练 第500批训练 第501批训练 第502批训练 第503批训练 第504批训练 第505批训练 第506批训练 第507批训练 第508批训练 第509批训练 第510批训练 第511批训练 第512批训练 第513批训练 第514批训练 第515批训练 第516批训练 第517批训练 第518批训练 第519批训练 第520批训练 第521批训练 第522批训练 第523批训练 第524批训练 第525批训练 第526批训练 第527批训练 第528批训练 第529批训练 第530批训练 第531批训练 第532批训练 第533批训练 第534批训练 第535批训练 第536批训练 第537批训练 第538批训练 第539批训练 第540批训练 第541批训练 第542批训练 第543批训练 第544批训练 第545批训练 第546批训练 第547批训练 第548批训练 第549批训练 第550批训练 第551批训练 第552批训练 第553批训练 第554批训练 第555批训练 第556批训练 第557批训练 第558批训练 第559批训练 第560批训练 第561批训练 第562批训练 第563批训练 第564批训练 第565批训练 第566批训练 第567批训练 第568批训练 第569批训练 第570批训练 第571批训练 第572批训练 第573批训练 第574批训练 第575批训练 第576批训练 第577批训练 第578批训练 第579批训练 第580批训练 第581批训练 第582批训练 第583批训练 第584批训练 第585批训练 第586批训练 第587批训练 第588批训练 第589批训练 第590批训练 第591批训练 第592批训练 第593批训练 第594批训练 第595批训练 第596批训练 第597批训练 第598批训练 第599批训练 第600批训练 Epoch 1: Train Loss: 0.14750646080588922 Train Accuracy: 0.9542000000000053 Test Loss: 0.04495963536784984 Test Accuarcy: 0.9845999999999998