基于 tensorflow 的手写数字的识别(进阶)
本系列将分为 8 篇 。本次为第 8 篇 ,基于 tensorflow ,利用卷积神经网络 CNN 进行手写数字识别 。
1.引言
关于 mnist 数据集的介绍和卷积神经网络的笔记在本系列文章中已有过介绍 ,有需要可见下述两篇文章 。本系列第 5 篇曾实现利用最简单的 BP 神经网络进行手写数字识别 。本系列第 6 篇简单介绍了下卷积神经网络的知识 。
2.设计的 CNN 结构
本系列第 4 讲讲过实战可以大致分为 "三步走"
- 定义神经网络的结构和前向传播的输出结果
- 定义损失函数以及选择反向传播优化的算法
- 生成会话(tf.Session) 并在训练数据上反复运行反向传播优化算法
这里也一样 ,当然首先是设计我们针对此实战的卷积神经网络 ,设计一个最简单的如下手绘 (还是那句话 ,字丑人帅 ,拒绝反驳)
上图得到两次卷积池化结果后 ,将结果展平为 1 维向量 ,即1 *(7*7*64),再连接到十个节点的输出层 。
3.手动干起来 !
首先 ,需要读取 MNIST 数据集 ,利用 TF 框架自带类进行下载读取 。
接下来就是根据之前的 “三步走” 进行实践 。实现上述的网络结构 ,并依旧选择二次代价函数和梯度下降法 。
首先 ,定义两个函数 ,用于初始化参数 。再定义两个函数实现卷积核池化(只是便于模块化 ,提高可读性)。
根据上述手绘结构图进行编程实现该结构 。
这里有一个 dropout 操作 ,目的是训练过程中使一部分神经元参数不变 ,即不参与训练 ,相当于简化结构 ,减少过拟合 。
再在会话 Session 中执行 ,并保存好模型参数 。
测试结果(小詹在按时付费的某服务器跑的结果)如下图 :
上述代码获取方式 ,后台回复关键词【S8】即可 。