Pytorch全连接神经网络实现手写数字识别

简介: Pytorch全连接神经网络实现手写数字识别

问题

Mnist手写数字识别数据集作为一个常见数据集,包含10个类别,在此次深度学习的过程中,我们通过pytorch提供的库函数,运用全连接神经网络实现手写数字的识别


方法

设置参数

input_size = 784
hidden_size = 500
output_size = 10
num_epochs = 5
batch_size = 100
l2earning_rate = 0.001

下载mnist数据集,并将其分为训练集和测试集

定义一个带有隐藏层的全连接神经网络

class NeuralNet(nn.Module):
    def__init__(self,input_size,hidden_size,output_size):
       super(NeuralNet, self).__init__()
       self.fc1 = nn.Linear(input_size, hidden_size)
       self.relu = nn.ReLU()
       self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
    out = self.fc1(x)
    out = self.relu(out)
    out = self.fc2(out)
    return out
model=NeuralNet(input_size,hidden_size,output_size).to(device)   #类的实例化

损失函数和优化算法

训练模型

total_step = len(train_loader)  #训练数据的大小,也就是含有多少个barch
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
       images = images.reshape(-1, 28*28).to(device)    
       labels = labels.to(device)
       outputs = model(images)
       loss = criterion(outputs, labels)
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       if (i+1) % 100 == 0:
           print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

测试模型

实验结果


结语

通过此次试验发现,在训练数据时,传入网络的是一个独立标签,即,我们希望输出的是2,但输出的不是用实数2做标签,而是用一个表示实数2的一个十维向量[0,0,1,0,0,0,0,0,0,0],对于分类问题,这种表示尤为重要。

目录
相关文章
|
2月前
|
传感器 运维 物联网
蓝牙Mesh网络:连接未来的智能解决方案
蓝牙Mesh网络:连接未来的智能解决方案
186 12
|
3月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 中的动态计算图:实现灵活的神经网络架构
【8月更文第27天】PyTorch 是一款流行的深度学习框架,它以其灵活性和易用性而闻名。与 TensorFlow 等其他框架相比,PyTorch 最大的特点之一是支持动态计算图。这意味着开发者可以在运行时定义网络结构,这为构建复杂的模型提供了极大的便利。本文将深入探讨 PyTorch 中动态计算图的工作原理,并通过一些示例代码展示如何利用这一特性来构建灵活的神经网络架构。
204 1
|
7天前
|
安全 网络架构
无线网络:连接未来的无形纽带
【10月更文挑战第13天】
41 8
|
20天前
|
存储 网络协议 Java
【网络】UDP回显服务器和客户端的构造,以及连接流程
【网络】UDP回显服务器和客户端的构造,以及连接流程
48 2
|
10天前
|
人工智能 安全 搜索推荐
|
14天前
|
监控 安全 5G
|
2月前
|
机器学习/深度学习
小土堆-pytorch-神经网络-损失函数与反向传播_笔记
在使用损失函数时,关键在于匹配输入和输出形状。例如,在L1Loss中,输入形状中的N代表批量大小。以下是具体示例:对于相同形状的输入和目标张量,L1Loss默认计算差值并求平均;此外,均方误差(MSE)也是常用损失函数。实战中,损失函数用于计算模型输出与真实标签间的差距,并通过反向传播更新模型参数。
|
22天前
|
安全 5G 网络安全
5G 网络中的认证机制:构建安全连接的基石
5G 网络中的认证机制:构建安全连接的基石
35 0
|
3月前
|
机器学习/深度学习 PyTorch 测试技术
深度学习入门:使用 PyTorch 构建和训练你的第一个神经网络
【8月更文第29天】深度学习是机器学习的一个分支,它利用多层非线性处理单元(即神经网络)来解决复杂的模式识别问题。PyTorch 是一个强大的深度学习框架,它提供了灵活的 API 和动态计算图,非常适合初学者和研究者使用。
45 0
|
3月前
|
Kubernetes 监控 Shell
在K8S中,我们公司用户反应pod连接数非常多,希望看一下这些连接都是什么信息?什么状态?怎么排查?容器里面没有集成bash环境、网络工具,怎么处理?
在K8S中,我们公司用户反应pod连接数非常多,希望看一下这些连接都是什么信息?什么状态?怎么排查?容器里面没有集成bash环境、网络工具,怎么处理?

热门文章

最新文章