手写数字识别基本思路

简介: 手写数字识别基本思路

问题

什么是MNIST?如何使用Pytorch实现手写数字识别?如何进行手写数字对模型进行检验?


方法

mnist数据集

MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。


使用Pytorch实现手写数字识别

1.进行数据预处理对于MNIST数据集,可以通过torchvision中的datasets进行下载。

root (string):表示数据集的根目录,其中根目录存在MNIST/processed/training.pt和MNIST/processed/test.pt的子目录。

train (bool, optional):如果为True,则从training.pt创建数据集,否则从test.pt创建数据集。

download (bool, optional):如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载。

transform (callable, optional):接收PIL图片并返回转换后版本图片的转换函数。

bat = 128
transform = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize(0.1307, 0.3081)  # (均值,方差)
])  # Compoes 两个操作合为一个
train_ds = datasets.MNIST(root='data', download=False, train=True,
                         transform=transform)
train_ds, val_ds = torch.utils.data.random_split(train_ds, [50000, 10000])
test_ds = datasets.MNIST(root='data', download=True, train=False,
                        transform=transform)
train_loader = DataLoader(dataset=train_ds, batch_size=bat, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=bat)
test_loader = DataLoader(dataset=test_ds, batch_size=bat)


2.构建模型

class MyNet(nn.Module):
   def __init__(self) -> None:
       super().__init__()

       self.flatten = nn.Flatten()  # 将28*28的图像拉伸为784维向量
       # 第一个全连接层Full Connection(FC)
       self.fc1 = nn.Linear(in_features=784,
                            out_features=256)
       self.fc2 = nn.Linear(in_features=256,
                            out_features=128)
       self.fc3 = nn.Linear(in_features=128,
                            out_features=10)

   def forward(self, x):
       x = self.flatten(x)
       x = torch.relu(self.fc1(x))
       x = torch.relu(self.fc2(x))
       out = torch.relu(self.fc3(x))
       return out

构建一个三层的神经网络MNIST数据集中的图片都是28×28大小的,而且是灰度图。而全连接神经网络的输入要是一个行向量,所以我们要把28×28的矩阵转换成28×28=764的行向量,作为神经网络的输入


3.优化器的选择,参数设置

使用优化器和损失函数。优化器选择SGD,SGD随机梯度下降,lr学习率取值0.2最优,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生。

optimizer=torch.optim.SGD(net.parameters(),lr=0.2)#lr学习率,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生
#损失函数
#衡量yy_hat之间的差异
loss_fn=nn.CrossEntropyLoss()


4.对模型进行训练测试,网络的输入,输入尺寸B*C*H*W B是batch,一个batch一个batch交给网络处理,x=torch.rand(size=(128,1,28,28)),基于loss信息利用优化器从后向前更新网络全部参数。

def train(dataloader, net, loss_fn, optimizer, epoch):
   size = len(dataloader.dataset)
   corrent = 0
   epoch_loss = 0.0
   batch_num = len(dataloader)
   net.train()

   # 一个batch一个batch的训练网络
   for batch_idx, (X, y) in enumerate(dataloader):
       pred = net(X)

       # 衡量y与y_hat之间的loss
       # y:128, pred:128x10 CrossEntropyloss
       loss = loss_fn(pred, y)

       # 基于loss信息利用优化器从后向前更新网络全部参数 <---
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       epoch_loss += loss.item()
       corrent += (pred.argmax(1) == y).type(torch.float).sum().item()
       if batch_idx % 100 == 0:
           # f-string
           print(f'[{batch_idx + 1:>5d}/{batch_num + 1:>5d}],loss:{loss.item()}')
   avg_loss = epoch_loss / batch_num
   avg_accuracy = corrent / size
   # loss_list.append(avg_loss)
   return avg_accuracy, avg_loss
def test(dataloader, net, loss_fn):
   size = len(dataloader.dataset)
   batch_num = len(dataloader)
   corrent = 0
   losses = 0
   net.eval()
   with torch.no_grad():
       for X, y in test_loader:
           pred = net(X)
           correct = (pred.argmax(1) == y).type(torch.int).sum().item()
           # print(y.size(0))
           # print(correct)
           corrent += correct
   accuracy = corrent / size
   avg_loss = losses / batch_num
   return accuracy, avg_loss



5.保存最优的模型

net.load_state_dict(torch.load('model_best.pth'))
   test(test_loader,net,loss_fn)


6.读入自己的写入数字,进行识别

model = MyNet()
model.load_state_dict(torch.load('model_best.pth'))
img = Image.open("7.png").convert("L")  # 转为灰度图像
img = transform(img)
# img = np.array(img)
# print(img)
result = model(img)
_, predict = torch.max(result.data, dim=1)
print(result)
print("the result is:",predict.item())


结语

minist是一个28*28的图像,所以输入就是28*28=784的维度,输出为10,0-9十个数字。手写数字识别首先需要初始化全局变量,构建数据集。然后构建模型,构建迭代器与损失函数,进行训练测试。最后可以将训练的模型进行保存,通过读取自己写的数字进行识别验证,完成一个简单的深度学习。

目录
相关文章
|
存储 SQL 人工智能
新年将至,为大家推荐一款开源AI红包封面制作神器AiCover!
新年将至,为大家推荐一款开源AI红包封面制作神器AiCover!
331 2
|
JSON JavaScript 数据格式
jwt-auth插件实现了基于JWT(JSON Web Tokens)进行认证鉴权的功能。
jwt-auth插件实现了基于JWT(JSON Web Tokens)进行认证鉴权的功能。
305 1
|
7月前
|
运维 安全 关系型数据库
Websoft9 运维面板,全网真正的一键部署应用
Websoft9运维面板实现应用真·一键部署,通过智能环境适配、安全架构与容器化技术,将传统数小时部署缩短至分钟级,显著提升效率与安全性。
204 5
|
11月前
|
编解码 前端开发 JavaScript
深入探讨 PostCSS 的特点、优势以及在实际开发中的应用
PostCSS是一款用JavaScript实现的CSS处理工具,通过丰富的插件生态,支持代码优化、格式化、兼容性处理及性能提升,极大提升了前端开发效率和代码质量。它高度可定制,易于集成现有工作流,适用于大型项目和复杂设计需求。
164 3
|
存储 Linux 调度
深入理解操作系统:从理论到实践
【9月更文挑战第32天】本文将带你深入了解操作系统的基本原理和实践应用。我们将从操作系统的定义开始,探讨它的基本功能和组件,然后深入到进程管理、内存管理、文件系统等核心概念。最后,我们将通过一个简单的代码示例来展示如何在实际操作中应用这些理论知识。无论你是计算机专业的学生,还是对操作系统感兴趣的开发者,这篇文章都将为你提供有价值的参考。
|
机器学习/深度学习 算法
【MATLAB】基于EMD-PCA-LSTM的回归预测模型
【MATLAB】基于EMD-PCA-LSTM的回归预测模型
495 0
【MATLAB】基于EMD-PCA-LSTM的回归预测模型
|
人工智能 开发者
Kimi Chat:国内AI新星,20万字超长文本处理的突破者
【2月更文挑战第12天】Kimi Chat:国内AI新星,20万字超长文本处理的突破者
3149 2
Kimi Chat:国内AI新星,20万字超长文本处理的突破者
|
缓存 小程序 前端开发
谈谈钉钉工作台的体验优化及技术思考
本文主要介绍本次体验优化专项的特点,产品能力体验升级背后的技术思考,以及技术视角优化的关键策略和结果。
|
SQL 运维 Java
HBASE启动脚本/Shell解析
常用到的HBase启动脚本有: 1.$HBASE_HOME/bin/start-hbase.sh 启动整个集群 2.$HBASE_HOME/bin/stop-hbase.sh 停止整个集群 3.$HBASE_HOME/bin/hbase-daemons.sh 启动或停止,所有的regionserver或zookeeper或backup-master 4.$HBASE_HOME/bin/hbase-daemon.sh 启动或停止,单个master或regionserver或zookeeper 以start-hbase.sh为起点,可以看看脚本间的一些调用关系
985 0
layui 请求返回401token 过期 重新登陆
layui 请求返回401token 过期 重新登陆
275 0