基于tensorflow+DNN的MNIST数据集手写数字分类

简介: 2018年9月17日笔记tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。DNN是deep neural network的简称,中文叫做深层神经网络,有时也叫做多层感知机(Multi-Layer perceptron,MLP)。

2018年9月17日笔记

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。
DNN是deep neural network的简称,中文叫做深层神经网络,有时也叫做多层感知机(Multi-Layer perceptron,MLP)。
从DNN按不同层的位置划分,DNN内部的神经网络层可以分为三类,输入层,隐藏层和输出层。
如下图示例,一般来说第一层是输入层,最后一层是输出层,而中间的层数都是隐藏层。

img_ff2e5dc3e95249daf43922313465f706.png
image.png

MNIST是Mixed National Institue of Standards and Technology database的简称,中文叫做 美国国家标准与技术研究所数据库
此文在上一篇文章《基于tensorflow的MNIST数据集手写数字分类预测》的基础上添加了1个隐藏层,模型准确率从91%提升到98%
《基于tensorflow的MNIST数据集手写数字分类预测》文章链接: https://www.jianshu.com/p/135c21e3db73

0.编程环境

安装tensorflow命令:pip install tensorflow
操作系统:Win10
tensorflow版本:1.6
tensorboard版本:1.6
python版本:3.6

1.致谢声明

1.本文是作者学习《周莫烦tensorflow视频教程》的成果,感激前辈;
视频链接:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/
2.参考云水木石的文章,链接:https://mp.weixin.qq.com/s/H9I0KX0CBkHeap5Xpwp-5Q

2.下载并解压数据集

MNIST数据集下载链接: https://pan.baidu.com/s/1fPbgMqsEvk2WyM9hy5Em6w 密码: wa9p
下载压缩文件MNIST_data.rar完成后,选择解压到当前文件夹不要选择解压到MNIST_data。
文件夹结构如下图所示:

img_644ec67b5b6dd7b9c15370e35245b53d.png
image.png

3.完整代码

此章给读者能够直接运行的完整代码,使读者有编程结果的感性认识。
如果下面一段代码运行成功,则说明安装tensorflow环境成功。
想要了解代码的具体实现细节,请阅读后面的章节。
在迭代训练5000次后,模型的准确率可以到达98%左右,下面代码为了节省读者运行时间,只迭代训练1000次。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
X_holder = tf.placeholder(tf.float32)
y_holder = tf.placeholder(tf.float32)

def addConnect(inputs, in_size, out_size, activation_function=None):
    Weights = tf.Variable(tf.truncated_normal([in_size, out_size], stddev=0.01))
    biases = tf.Variable(tf.zeros([1, out_size]))
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    if activation_function is None:
        return Wx_plus_b
    else:
        return activation_function(Wx_plus_b)

connect_1 = addConnect(X_holder, 784, 300, tf.nn.relu)
predict_y = addConnect(connect_1, 300, 10, tf.nn.softmax)
loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
optimizer = tf.train.AdagradOptimizer(0.3)
train = optimizer.minimize(loss)

session = tf.Session()
init = tf.global_variables_initializer()
session.run(init)

for i in range(1000):
    images, labels = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:images, y_holder:labels})
    if i % 50 == 0:
        correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})
        print('step:%d accuracy:%.4f' %(i, accuracy_value))

第12行代码tf.truncated_normal方法与tf.random_normal方法的区别如下图所示。
truncated中文叫做被切去顶端的,tf.truncated_normal方法产生的随机数都处于均值两边2个标准差之内。

img_6908f6ce3b35e296a40ced05f2ed4286.png
image.png

上面一段代码的运行结果如下:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
step:0 accuracy:0.4195
step:50 accuracy:0.8827
step:100 accuracy:0.9144
step:150 accuracy:0.9175
step:200 accuracy:0.9391
step:250 accuracy:0.9422
step:300 accuracy:0.9401
step:350 accuracy:0.9550
step:400 accuracy:0.9581
step:450 accuracy:0.9568
step:500 accuracy:0.9531
step:550 accuracy:0.9618
step:600 accuracy:0.9601
step:650 accuracy:0.9586
step:700 accuracy:0.9599
step:750 accuracy:0.9651
step:800 accuracy:0.9673
step:850 accuracy:0.9691
step:900 accuracy:0.9701
step:950 accuracy:0.9667

从上面的运行结果可以看出,经过1000次迭代训练,模型准确率到达0.9667左右。

4.数据准备

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
X_holder = tf.placeholder(tf.float32)
y_holder = tf.placeholder(tf.float32)

第1行代码导入warnings库,第2行代码表示不打印警告信息;
第3行代码导入tensorflow库,取别名tf;
第4行代码人从tensorflow.examples.tutorials.mnist库中导入input_data文件;
本文作者使用anaconda集成开发环境,input_data文件所在路径:C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow\examples\tutorials\mnist,如下图所示:

img_4e71145fa44d78d15c3773b8a5965a46.png
image.png

第6行代码调用input_data文件的read_data_sets方法,需要2个参数,第1个参数的数据类型是字符串,是读取数据的文件夹名,第2个关键字参数ont_hot数据类型为布尔bool,设置为True,表示预测目标值是否经过One-Hot编码;
第7行代码定义变量batch_size的值为100;
第8、9行代码中placeholder中文叫做占位符,将每次训练的特征矩阵X和预测目标值y赋值给变量X_holder和y_holder。

5.数据观察

本章内容主要是了解变量mnist中的数据内容,并掌握变量mnist中的方法使用。

5.1 查看变量mnist的方法和属性

dir(mnist)[-10:]

上面一段代码的运行结果如下:

['_asdict',
'_fields',
'_make',
'_replace',
'_source',
'count',
'index',
'test',
'train',
'validation']

为了节省篇幅,只打印最后10个方法和属性。
我们会用到的是其中test、train、validation这3个方法。

5.2 对比三个集合

train对应训练集,validation对应验证集,test对应测试集。
查看3个集合中的样本数量,代码如下:

print(mnist.train.num_examples)
print(mnist.validation.num_examples)
print(mnist.test.num_examples)

上面一段代码的运行结果如下:

55000
5000
10000

对比3个集合的方法和属性


img_46e9f63efbae404dd52b70e39a6ea73f.png
image.png

从上面的运行结果可以看出,3个集合的方法和属性基本相同。
我们会用到的是其中images、labels、next_batch这3个属性或方法。

5.3 mnist.train.images观察

查看mnist.train.images的数据类型和矩阵形状。

images = mnist.train.images
type(images), images.shape

上面一段代码的运行结果如下:

(numpy.ndarray, (55000, 784))

从上面的运行结果可以看出,在变量mnist.train中总共有55000个样本,每个样本有784个特征。
原图片形状为28*28,28*28=784,每个图片样本展平后则有784维特征。
选取1个样本,用3种作图方式查看其图片内容,代码如下:

import matplotlib.pyplot as plt

image = mnist.train.images[1].reshape(-1, 28)
plt.subplot(131)
plt.imshow(image)
plt.axis('off')
plt.subplot(132)
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.subplot(133)
plt.imshow(image, cmap='gray_r')
plt.axis('off')
plt.show()

上面一段代码的运行结果如下图所示:

img_f69715f9b2e6ca9110961b6a878d39a4.png
image.png

从上面的运行结果可以看出,调用plt.show方法时,参数cmap指定值为 graygray_r符合正常的观看效果。

5.4 查看手写数字图

从训练集mnist.train中选取一部分样本查看图片内容,即调用mnist.train的next_batch方法随机获得一部分样本,代码如下:

import matplotlib.pyplot as plt
import math
import numpy as np

def drawDigit(position, image, title):
    plt.subplot(*position)
    plt.imshow(image.reshape(-1, 28), cmap='gray_r')
    plt.axis('off')
    plt.title(title)
    
def batchDraw(batch_size):
    images,labels = mnist.train.next_batch(batch_size)
    image_number = images.shape[0]
    row_number = math.ceil(image_number ** 0.5)
    column_number = row_number
    plt.figure(figsize=(row_number, column_number))
    for i in range(row_number):
        for j in range(column_number):
            index = i * column_number + j
            if index < image_number:
                position = (row_number, column_number, index+1)
                image = images[index]
                title = 'actual:%d' %(np.argmax(labels[index]))
                drawDigit(position, image, title)

batchDraw(196)
plt.show()

上面一段代码的运行结果如下图所示,本文作者对难以辨认的数字做了红色方框标注:


img_03e17fd40d91b5bcf8c21e518bdbfaad.png
image.png

6.搭建神经网络

def addConnect(inputs, in_size, out_size, activation_function=None):
    Weights = tf.Variable(tf.random_normal([in_size, out_size], stddev=0.01))
    biases = tf.Variable(tf.zeros([1, out_size]))
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    if activation_function is None:
        return Wx_plus_b
    else:
        return activation_function(Wx_plus_b)

connect_1 = addConnect(X_holder, 784, 300, tf.nn.relu)
predict_y = addConnect(connect_1, 300, 10, tf.nn.softmax)
loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
optimizer = tf.train.AdagradOptimizer(0.3)
train = optimizer.minimize(loss)

第1-8行代码定义addConnect函数,即在神经网络中添加1个连接层;
addConnect函数需要4个参数,第1个参数是输入层矩阵Inputs;
第2个参数是连接上一层神经元个数in_size,数据类型为整数;
第3个参数是连接下一层神经元个数,数据类型为整数;
第4个参数是激活函数。数据类型为函数对象。
第10行代码添加第1个连接层,并将其输出结果赋值给变量connect_1;
第11行代码添加第2个连接层,并将其输出结果赋值给变量predict_y,即标签预测值;
第12行代码定义损失函数loss,因为是多分类问题,使用交叉熵作为损失函数,tf.reduce_sum函数的第2个参数为1的原因是表示对行求和, 如果第2个参数为0节表示对列求和。
第13行代码定义优化器optimizer,作者使用过GradientDescentOptimizer、AdamOptimizer,经过实践对比,AdagradOptimizer在此问题的收敛效果较好,读者可以自己尝试设置不同的优化的效果;
第14行代码定义训练过程,即用优化器最小化损失。

7.变量初始化

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

对于神经网络模型,重要是其中的W、b这两个参数。
开始神经网络模型训练之前,这两个变量需要初始化。
第1行代码调用tf.global_variables_initializer实例化tensorflow中的Operation对象。


img_eba6278c89eb8ef15ee4daee0eaab711.png
image.png

第2行代码调用tf.Session方法实例化会话对象;
第3行代码调用tf.Session对象的run方法做变量初始化。

8.模型训练

for i in range(1000):
    images, labels = mnist.train.next_batch(batch_size)
    session.run(train, feed_dict={X_holder:images, y_holder:labels})
    if i % 50 == 0:
        correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})
        print('step:%d accuracy:%.4f' %(i, accuracy_value))

第1行代码表示模型迭代训练1000次;
第2行代码调用mnist.train对象的next_batch方法,选出数量为batch_size的样本;
第3行代码是模型训练,每运行1次此行代码,即模型训练1次;
第4-8行代码是每隔25次训练打印模型准确率。

9.模型测试

import math
import matplotlib.pyplot as plt
import numpy as np

def drawDigit2(position, image, title, isTrue):
    plt.subplot(*position)
    plt.imshow(image.reshape(-1, 28), cmap='gray_r')
    plt.axis('off')
    if not isTrue:
        plt.title(title, color='red')
    else:
        plt.title(title)
        
def batchDraw2(batch_size):
    images,labels = mnist.test.next_batch(batch_size)
    predict_labels = session.run(predict_y, feed_dict={X_holder:images, y_holder:labels})
    image_number = images.shape[0]
    row_number = math.ceil(image_number ** 0.5)
    column_number = row_number
    plt.figure(figsize=(row_number+8, column_number+8))
    for i in range(row_number):
        for j in range(column_number):
            index = i * column_number + j
            if index < image_number:
                position = (row_number, column_number, index+1)
                image = images[index]
                actual = np.argmax(labels[index])
                predict = np.argmax(predict_labels[index])
                isTrue = actual==predict
                title = 'actual:%d\npredict:%d' %(actual,predict)
                drawDigit2(position, image, title, isTrue)

batchDraw2(100)
plt.show()

上面一段代码的运行结果如下图所示:


img_7a6d919f998e917211c3be72a5ce01a8.png
image.png

从上面的运行结果可以看出,100个数字中只错了3个,符合前1章准确率为97%左右的计算结果。

10.结论

1.这是本文作者写的第5篇关于tensorflow的文章,加深了对tensorflow框架的理解;
2.通过代码实践,本文作者掌握了调整学习率和权重初始化的要点和技巧;

目录
相关文章
|
1月前
|
数据采集 TensorFlow 算法框架/工具
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
39 0
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
|
13天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
39 3
|
3月前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
79 1
|
3月前
|
机器学习/深度学习 API 算法框架/工具
【Tensorflow+keras】Keras API三种搭建神经网络的方式及以mnist举例实现
使用Keras API构建神经网络的三种方法:使用Sequential模型、使用函数式API以及通过继承Model类来自定义模型,并提供了基于MNIST数据集的示例代码。
56 12
|
3月前
|
UED 存储 数据管理
深度解析 Uno Platform 离线状态处理技巧:从网络检测到本地存储同步,全方位提升跨平台应用在无网环境下的用户体验与数据管理策略
【8月更文挑战第31天】处理离线状态下的用户体验是现代应用开发的关键。本文通过在线笔记应用案例,介绍如何使用 Uno Platform 优雅地应对离线状态。首先,利用 `NetworkInformation` 类检测网络状态;其次,使用 SQLite 实现离线存储;然后,在网络恢复时同步数据;最后,通过 UI 反馈提升用户体验。
94 0
|
3月前
|
机器学习/深度学习 TensorFlow 数据处理
分布式训练在TensorFlow中的全面应用指南:掌握多机多卡配置与实践技巧,让大规模数据集训练变得轻而易举,大幅提升模型训练效率与性能
【8月更文挑战第31天】本文详细介绍了如何在Tensorflow中实现多机多卡的分布式训练,涵盖环境配置、模型定义、数据处理及训练执行等关键环节。通过具体示例代码,展示了使用`MultiWorkerMirroredStrategy`进行分布式训练的过程,帮助读者更好地应对大规模数据集与复杂模型带来的挑战,提升训练效率。
82 0
|
3月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Tensorflow+Keras】keras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例
如何使用TensorFlow和Keras实现条件生成对抗网络(CGAN)并以MNIST和Fashion MNIST数据集为例进行演示。
50 3
|
3月前
|
机器学习/深度学习 PyTorch TensorFlow
【机器学习】基于tensorflow实现你的第一个DNN网络
【机器学习】基于tensorflow实现你的第一个DNN网络
59 0
|
3月前
|
机器学习/深度学习 API TensorFlow
【Tensorflow+keras】解决 Fail to find the dnn implementation.
在TensorFlow 2.0环境中使用双向长短期记忆层(Bidirectional LSTM)遇到“Fail to find the dnn implementation”错误时的三种解决方案。
71 0
|
6月前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow、Keras 和 Python 构建神经网络分析鸢尾花iris数据集|代码数据分享
TensorFlow、Keras 和 Python 构建神经网络分析鸢尾花iris数据集|代码数据分享

热门文章

最新文章

下一篇
无影云桌面