基于PaddlePaddle框架经典的全连接神经网络实现手写数字图片识别

简介: 基于PaddlePaddle框架经典的全连接神经网络实现手写数字图片识别

使用经典的全连接神经网络实现手写数字图片识别


image.pngimage.png


库函数介绍


  • numpy    :python第三方库,用于进行科学计算
  • PIL      :Python Image Library,python第三方图像处理库
  • matplotlib:python的绘图库 pyplot:matplotlib的绘图框架
  • os        :提供了丰富的方法来处理文件和目录
#导入需要的包
import numpy as np
import paddle
import paddle.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import os
print("本教程基于Paddle的版本号为:"+paddle.__version__)


Step1:准备数据。


(1)数据集介绍

MNIST数据集包含60000个训练集和10000测试数据集。分为图片和标签,图片是28*28的像素矩阵,标签为0~9共10个数字。

image.png

(2)transform函数是定义了一个归一化标准化的标准


(3)train_dataset和test_dataset


paddle.vision.datasets.MNIST()中的mode='train'和mode='test'分别用于获取mnist训练集和测试集


transform=transform参数则为归一化标准

#导入数据集Compose的作用是将用于数据集预处理的接口以列表的方式进行组合。
#导入数据集Normalize的作用是图像归一化处理,支持两种方式: 1. 用统一的均值和标准差值对图像的每个通道进行归一化处理; 2. 对每个通道指定不同的均值和标准差值进行归一化处理。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
# 使用transform对数据集做归一化
print('下载并加载训练数据')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('加载完成')
下载并加载训练数据
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz 
Begin to download
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz 
Begin to download
........
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz 
Begin to download
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz 
Begin to download
..
Download finished
加载完成
# 输出图片
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
print(plt.imshow(train_data0, cmap=plt.cm.binary))
print('train_data0 的标签为: ' + str(train_label_0))
AxesImage(18,18;111.6x108.72)
train_data0 的标签为: [5]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_max = np.asscalar(a_max.astype(scaled_dtype))
# print(train_data0)


Step2.网络配置


以下的代码判断就是定义一个简单的全连接神经网络,本示例一共有三层,两个大小为128的隐层和一个大小为10的输出层。


MNIST数据集是手写0到9的灰度图像,类别有10个,所以最后的输出大小是10。

本实践的网络结构为:输入层-->>隐层-->>隐层-->>输出层。

image.png

# 定义全连接神经网络
class Mnist(nn.Layer):
    def __init__(self):
        super(Mnist,self).__init__()
        self.fc1 = nn.Linear(in_features = 28*28, out_features = 256)
        self.fc2 = nn.Linear(in_features = 256, out_features =128)
        self.fc3 = nn.Linear(in_features = 128, out_features = 10)
    def forward(self, input): 
        bsz = input.shape[0]
        x = paddle.reshape(input,[bsz,-1])
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        y = self.fc3(x)
        return y
from paddle.metric import Accuracy
# 用Model封装模型
model = paddle.Model(Mnist())   
# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 配置模型
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())
# 训练保存并验证模型
model.fit(train_dataset,test_dataset,epochs=2,batch_size=64,save_dir='model.pd',verbose=1)
W1204 13:02:37.934275   128 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1204 13:02:37.939237   128 device_context.cc:372] device: 0, cuDNN Version: 7.6.
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/2
step  30/938 [..............................] - loss: 0.4677 - acc: 0.6531 - ETA: 17s - 20ms/st
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:89: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return (isinstance(seq, collections.Sequence) and
step  50/938 [>.............................] - loss: 0.6765 - acc: 0.7234 - ETA: 12s - 14ms/stepstep 938/938 [==============================] - loss: 0.2223 - acc: 0.9111 - 7ms/step        
save checkpoint at /home/aistudio/model.pd/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0849 - acc: 0.9543 - 6ms/step         
Eval samples: 10000
Epoch 2/2
step 938/938 [==============================] - loss: 0.0974 - acc: 0.9580 - 6ms/step         
save checkpoint at /home/aistudio/model.pd/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0055 - acc: 0.9571 - 6ms/step         
Eval samples: 10000
save checkpoint at /home/aistudio/model.pd/final


Step3.模型评估


res = model.evaluate(test_dataset,batch_size=64, verbose=1)
print("测试集损失值:{}".format(res["loss"]))
print("测试集准确率:{}".format(res["acc"]))
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0055 - acc: 0.9571 - 6ms/step         
Eval samples: 10000
测试集损失值:[0.0054676016]
测试集准确率:0.9571


Step4.模型预测


#获取测试集的第一个图片
test_data0, test_label_0 = test_dataset[0][0],test_dataset[0][1]
test_data0 = test_data0.reshape([28,28])
plt.figure(figsize=(2,2))
#展示测试集中的第一个图片
print(plt.imshow(test_data0, cmap=plt.cm.binary))
print('test_data0 的标签为: ' + str(test_label_0))
#模型预测
result = model.predict(test_dataset, batch_size=1)
#打印模型预测的结果
print('test_data0 预测的数值为:%d' % np.argsort(result[0][0])[0][-1])
AxesImage(18,18;111.6x108.72)
test_data0 的标签为: [7]
Predict begin...
step 10000/10000 [==============================] - 1ms/step        
Predict samples: 10000
test_data0 预测的数值为:7

image.png

目录
相关文章
|
1天前
|
Ubuntu
虚拟机Ubuntu连接不了网络的解决方法
虚拟机Ubuntu连接不了网络的解决方法
|
1天前
|
机器学习/深度学习 Python
【Python实战】——神经网络识别手写数字(三)
【Python实战】——神经网络识别手写数字
|
1天前
|
机器学习/深度学习 数据可视化 Python
【Python实战】——神经网络识别手写数字(二)
【Python实战】——神经网络识别手写数字(三)
|
1天前
|
存储 分布式计算 监控
Hadoop【基础知识 01+02】【分布式文件系统HDFS设计原理+特点+存储原理】(部分图片来源于网络)【分布式计算框架MapReduce核心概念+编程模型+combiner&partitioner+词频统计案例解析与进阶+作业的生命周期】(图片来源于网络)
【4月更文挑战第3天】【分布式文件系统HDFS设计原理+特点+存储原理】(部分图片来源于网络)【分布式计算框架MapReduce核心概念+编程模型+combiner&partitioner+词频统计案例解析与进阶+作业的生命周期】(图片来源于网络)
139 2
|
1天前
|
机器学习/深度学习 数据可视化 Python
【Python实战】——神经网络识别手写数字(一)
【Python实战】——神经网络识别手写数字
|
1天前
|
存储 网络协议 Linux
RTnet – 灵活的硬实时网络框架
本文介绍了开源项目 RTnet。RTnet 为以太网和其他传输媒体上的硬实时通信提供了一个可定制和可扩展的框架。 本文描述了 RTnet 的架构、核心组件和协议。
18 0
RTnet – 灵活的硬实时网络框架
|
1天前
|
机器学习/深度学习 PyTorch 算法框架/工具
Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
|
1天前
|
运维 安全 网络架构
【专栏】NAT技术是连接私有网络与互联网的关键,缓解IPv4地址短缺,增强安全性和管理性
【4月更文挑战第28天】NAT技术是连接私有网络与互联网的关键,缓解IPv4地址短缺,增强安全性和管理性。本文阐述了五大NAT类型:全锥形NAT(安全低,利于P2P)、限制锥形NAT(增加安全性)、端口限制锥形NAT(更安全,可能影响协议)、对称NAT(高安全,可能导致兼容性问题)和动态NAT(公网IP有限时适用)。选择NAT类型需考虑安全性、通信模式、IP地址数量和设备兼容性,以确保网络高效、安全运行。
|
1天前
|
安全 网络协议 物联网
城域以太网:连接城市的高速网络
【4月更文挑战第22天】
22 0
|
1天前
|
数据可视化 网络协议

热门文章

最新文章