PyTorch使用神经网络进行手写数字识别实战(附源码,包括损失图像和准确率图像)

简介: PyTorch使用神经网络进行手写数字识别实战(附源码,包括损失图像和准确率图像)

全部源码请点赞关注收藏后评论区留言即可~~~

下面使用torchvision.datasets.MNIST构建手写数字数据集。

1:数据预处理

PyTorch提供了torchvision.transforms用于处理数据及数据增强,它可以将数据从[0,255]映射到[0,1]

2:读取训练数据

准备好处理数据的流程后,就可以读取用于训练的数据了,torch.util.data.DataLoader提供了迭代数据,随机抽取数据,批量化数据等等功能 读取效果如下

预处理过后的数据如下

3:构建神经网络模型

下面构建用于识别手写数字的神经网络模型

class MLP(nn.Module):
    def __init__(self):
        super(MLP,self).__init__()
        self.inputlayer=nn.Sequential(nn.Linear(28*28,256),nn.ReLU(),nn.Dropout(0.2))
        self.hiddenlayer=nn.Sequential(nn.Linear(256,256),nn.ReLU(),nn.Dropout(0.2))
        self.outputlayer=nn.Sequential(nn.Linear(256,10))
    def forward(self,x):
        x=x.view(x.size(0),-1)
        x=self.inputlayer(x)
        x=self.hiddenlayer(x)
        x=self.outputlayer(x)
        return x

可以直接通过打印nn.Module的对象看到其网络结构

4:模型评估

在准备好数据和模型后,就可以训练模型了,下面分别定义了数据处理和加载流程,模型,优化器,损失函数以及用准确率评估模型能力。

得到的结果如下

训练一次 可以看出比较混乱 没有说明规律可言

训练五次的损失函数如下 可见随着训练次数的增加是逐渐收敛的,规律也非常明显

 

准确率图像如下

最后 部分源码如下

import torch
import torchvision
import  torch.nn as nn
from torch import  optim
from tqdm import  tqdm
import torch.utils.data.dataset
mnist=torchvision.datasets.MNIST(root='~',train=True,download=True)
for i,j in enumerate(np.random.randint(0,len(mnist),(10,))):
    data,label=mnist[j]
    plt.subplot(2,5,i+1)
    plt.show()
trans=transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,),(0.3081,))
    ]
)
normalized=trans(mnist[0][0])
from torchvision import  transforms
mnist=torchvision.datasets.MNIST(root='~',train=True,download=True,transform=trans)
def imshow(img):
    img=img*0.3081+0.1307
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
dataloader=DataLoader(mnist,batch_size=4,shuffle=True,num_workers=0)
images,labels=next(iter(dataloader))
imshow(torchvision.utils.make_grid(images))
class MLP(nn.Module):
    def __init__(self):
        super(MLP,self).__init__()
        self.inputlayer=nn.Sequential(nn.Linear(28*28,256),nn.ReLU(),nn.Dropout(0.2))
        self.hiddenlayer=nn.Sequential(nn.Linear(256,256),nn.ReLU(),nn.Dropout(0.2))
        self.outputlayer=nn.Sequential(nn.Linear(256,10))
    def forward(self,x):
        x=x.view(x.size(0),-1)
        x=self.inputlayer(x)
        x=self.hiddenlayer(x)
        x=self.outputlayer(x)
        return x
print(MLP())
trans=transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,),(0.3081,))
    ]
)
al=torchvision.datasets.MNIST(root='~',train=False,download=True,transform=trans)
trainloader=DataLoader(mnist_train,batch_size=16,shuffle=True,num_workers=0)
valloader=DataLoader(mnist_val,batch_size=16,shuffle=True,num_workers=0)
#模型
model=MLP()
#优化器
optimizer=oD(model.parameters(),lr=0.01,momentum=0.9)
#损失函数
celoss=nn.ssEntropyLoss()
best_acc=0
#计算准确率
def accuracy(pred,target):
    pred_label=torch.amax(pred,1)
    correct=sum(pred_label==target).to(torch.float)
    return correct,len(pred)
acc={'train':[],"val}
loss_all={'train':[],"val":[]}
for epoch in tqdm(range(5)):
    model.eval()
    numer_val,denumer_val,loss_tr=0.,0.,0.
    with torch.no_grad():
        for data,target in valloader:
            output=model(data)
            loss=celoss(output,target)
            loss_tr+=loss.data
            num,denum=accuracy(output,target)
            numer_val+=num
            denumer_val+=denum
    #设置为训练模式
    model.train()
    numer_tr,denumer_tr,loss_val=0.,0.,0.
    for data,target in trainloader:
        optizer.zero_grad()
        output=model(data)
        loss=celoss(output,target)
        loss_val+=loss.data
        loss.backward()
        optimer.step()
        num,denum=accuracy(output,target)
        numer_tr+=num
        denumer_tr+=denum
    loss_all['train'].append(loss_tr/len(trainloader))
    loss_all['val'].aend(lss_val/len(valloader))
    acc['train'].pend(numer_tr/denumer_tr)
    acc['val'].append(numer_val/denumer_val)
"""
plt.plot(loss_all['train'])
plt.plot(loss_all['val'])
"""
plt.plot(acc['train'])
plt.plot(acc['val'])
plt.show()

创作不易 觉得有帮助请点赞关注收藏~~~

相关文章
|
11天前
|
存储 SQL 安全
网络安全的盾牌:漏洞防护与加密技术的实战应用
【8月更文挑战第27天】在数字化浪潮中,信息安全成为保护个人隐私和企业资产的关键。本文深入探讨了网络安全的两大支柱——安全漏洞管理和数据加密技术,以及如何通过提升安全意识来构建坚固的防御体系。我们将从基础概念出发,逐步揭示网络攻击者如何利用安全漏洞进行入侵,介绍最新的加密算法和协议,并分享实用的安全实践技巧。最终,旨在为读者提供一套全面的网络安全解决方案,以应对日益复杂的网络威胁。
|
11天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 中的动态计算图:实现灵活的神经网络架构
【8月更文第27天】PyTorch 是一款流行的深度学习框架,它以其灵活性和易用性而闻名。与 TensorFlow 等其他框架相比,PyTorch 最大的特点之一是支持动态计算图。这意味着开发者可以在运行时定义网络结构,这为构建复杂的模型提供了极大的便利。本文将深入探讨 PyTorch 中动态计算图的工作原理,并通过一些示例代码展示如何利用这一特性来构建灵活的神经网络架构。
31 1
|
20天前
|
机器学习/深度学习 人工智能 PyTorch
【深度学习】使用PyTorch构建神经网络:深度学习实战指南
PyTorch是一个开源的Python机器学习库,特别专注于深度学习领域。它由Facebook的AI研究团队开发并维护,因其灵活的架构、动态计算图以及在科研和工业界的广泛支持而受到青睐。PyTorch提供了强大的GPU加速能力,使得在处理大规模数据集和复杂模型时效率极高。
148 59
|
8天前
|
运维 安全 应用服务中间件
自动化运维的利器:Ansible入门与实战网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【8月更文挑战第30天】在当今快速发展的IT时代,自动化运维已成为提升效率、减少错误的关键。本文将介绍Ansible,一种流行的自动化运维工具,通过简单易懂的语言和实际案例,带领读者从零开始掌握Ansible的使用。我们将一起探索如何利用Ansible简化日常的运维任务,实现快速部署和管理服务器,以及如何处理常见问题。无论你是运维新手还是希望提高工作效率的资深人士,这篇文章都将为你开启自动化运维的新篇章。
|
8天前
|
Java
【实战演练】JAVA网络编程高手养成记:URL与URLConnection的实战技巧,一学就会!
【实战演练】JAVA网络编程高手养成记:URL与URLConnection的实战技巧,一学就会!
21 3
|
9天前
|
Java API UED
【实战秘籍】Spring Boot开发者的福音:掌握网络防抖动,告别无效请求,提升用户体验!
【8月更文挑战第29天】网络防抖动技术能有效处理频繁触发的事件或请求,避免资源浪费,提升系统响应速度与用户体验。本文介绍如何在Spring Boot中实现防抖动,并提供代码示例。通过使用ScheduledExecutorService,可轻松实现延迟执行功能,确保仅在用户停止输入后才触发操作,大幅减少服务器负载。此外,还可利用`@Async`注解简化异步处理逻辑。防抖动是优化应用性能的关键策略,有助于打造高效稳定的软件系统。
24 2
|
17天前
|
数据采集 存储 前端开发
豆瓣评分9.0!Python3网络爬虫开发实战,堪称教学典范!
今天我们所处的时代是信息化时代,是数据驱动的人工智能时代。在人工智能、物联网时代,万物互联和物理世界的全面数字化使得人工智能可以基于这些数据产生优质的决策,从而对人类的生产生活产生巨大价值。 在这个以数据驱动为特征的时代,数据是最基础的。数据既可以通过研发产品获得,也可以通过爬虫采集公开数据获得,因此爬虫技术在这个快速发展的时代就显得尤为重要,高端爬虫人才的收人也在逐年提高。
|
25天前
|
机器学习/深度学习 PyTorch TensorFlow
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
69 1
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch代码实现神经网络
这段代码示例展示了如何在PyTorch中构建一个基础的卷积神经网络(CNN)。该网络包括两个卷积层,分别用于提取图像特征,每个卷积层后跟一个池化层以降低空间维度;之后是三个全连接层,用于分类输出。此结构适用于图像识别任务,并可根据具体应用调整参数与层数。
|
4天前
|
SQL 安全 网络安全
网络安全之盾:漏洞防御与加密技术的实战应用
【9月更文挑战第2天】在数字时代的浪潮中,网络安全成为保护个人隐私和企业资产的坚固盾牌。本文深入探讨了网络安全的两个核心要素:防御漏洞和加密技术。我们将从基础概念入手,逐步剖析常见的网络攻击手段,并分享如何通过实践加强安全意识。同时,提供代码示例以增强理解,旨在为读者构建一道坚不可摧的网络安全防线。
下一篇
DDNS