手把手,74行代码实现手写数字识别-阿里云开发者社区

开发者社区> 小旋风柴进> 正文

手把手,74行代码实现手写数字识别

简介:
+关注继续查看

0?wx_fmt=jpeg

1、 引言:不要站在岸上学游泳


“机器学习”是一个很实践的过程。就像刚开始学游泳,你在只在岸上比划一堆规定动作还不如先跳到水里熟悉水性学习来得快。以我们学习“机器学习”的经验来看,很多高大上的概念刚开始不懂也没关系,先写个东西来跑跑,有个感觉了之后再学习那些概念和理论就快多了。如果别人已经做好了轮子,直接拿过来用则更快。因此,本文直接用Michael Nielsen先生的代码(github地址: https://github.com/mnielsen/neural-networks-and-deep-learning)作为例子,给大家展现神经网络分析的普遍过程:导入数据,训练模型,优化模型,启发式理解等。


本文假设大家已经了解python的基本语法,并在自己机器上运行过简单python脚本。


2、 我们要解决的问题:手写数字识别


手写数字识别是机器学习领域中一个经典的问题,是一个看似对人类很简单却对程序十分复杂的问题。很多早期的验证码就是利用这个特点来区分人类和程序行为的,当然此处就不提12306近乎反人类的奇葩验证码了。


0?wx_fmt=jpeg

回到手写数字识别,比如我们要识别出一个手写的“9”,人类可能通过识别“上半部分一个圆圈,右下方引出一条竖线”就能进行判断。但用程序表达就似乎很困难了,你需要考虑非常多的描述方式,考虑非常多的特殊情况,最终发现程序写得非常复杂而且效果不好。


而用(机器学习)神经网络的方法,则提供了另一个思路:获取大量的手写数字的图像,并且已知它们表示的是哪个数字,以此为训练样本集合,自动生成一套模型(如神经网络的对应程序),依靠它来识别新的手写数字。


0?wx_fmt=jpeg 


本文中采用的数据集就是著名的“MNIST数据集”。它的收集者之一是人工智能领域著名的科学家——Yann LeCu。这个数据集有60000个训练样本数据集和10000个测试用例。运用本文展示的单隐层神经网络,就可以达到96%的正确率。


3、图解:解决问题的思路


我们可以用下图展示上面的粗略思路。


0?wx_fmt=jpeg 


但是如何由“训练集”来“生成模型”呢?


在这里我们使用反复推荐的逆推法——假设这个模型已经生成了,它应该满足什么样的特性,再以此特性为条件反过来求出模型。


可以推想而知,被生成的模型应该对于训练集的区分效果非常好,也就是相应的训练误差非常低。比如有一个未知其相应权重和偏移的神经网络,而训练神经网络的过程就是逐步确定这些未知参数的过程,最终使得这些参数确定的模型在训练集上的误差达到最小值。我们将会设计一个数量指标衡量这个误差,如果训练误差没有达到最小,我们将继续调整参数,直到这个指标达到最小。但这样训练出来的模型我们仍无法保证它面对新的数据仍会有这样好的识别效果,就需要用测试集对模型进行考核,得出的测试结果作为对模型的评价。因此,上图就可以细化成下图:



0?wx_fmt=jpeg 


但是,如果我们已经生成了多个模型,怎么从中选出最好的模型?一个自然的思路就是通过比较不同模型在测试集上的误差,挑选出误差最小的模型。这个想法看似没什么问题,但是随着你测试的模型增多,你会觉得用测试集筛选出来的模型也不那么可信。比如我们增加一个神经网络的隐藏层节点,就会产生新的对应权重,产生一个新的模型。但是我也不知道增加多少个节点是合适的,所以比较全面的想法就是尝试测试不同的节点数x∈(1,2,3,4,…,100), 来观察这些不同模型的测试误差,并挑出误差最小的模型。这时我们发现我们的模型其实多出来了一个参数x, 我们挑选模型的过程就是确定最优化的参数x 的过程。这个分析过程与上面训练参数的思路如出一辙!只是这个过程是基于同一个测试集,而不训练集。那么,不同的神经网络的层数是不是也是一个新的参数y∈(1,2,3,4,…,100), 也要经过这么个过程来“训练”?


我们会发现我们之前生成模型过程中很多不变的部分其实都是可以变换调节的,这些也是新的参数,比如训练次数、梯度下降过程的步长、规范化参数、学习回合数、minibatch 值等等,我们把他们叫做超参数。超参数是影响所求参数最终取值的参数,是机器学习模型里面的框架参数,可以理解成参数的参数,它们通常是手工设定,不断试错调整的,或者对一系列穷举出来的参数组合一通进行枚举(网格搜索)来确定。但无论如何,这也是基于同样一个数据集反复验证优化的结果。在这个数据集上最后的结果并不一定在新的数据继续有效。所以为了评估这个模型的识别效果,就需要用新的测试集对模型进行考核,得出的测试结果作为对模型的评价。这个新的测试集我们就直接叫“测试集”,之前那个用于筛选超参数的测试集,我们就叫做“交叉验证集”。筛选模型的过程其实就是交叉验证的过程。


所以,规范的方法的是将数据集拆分成三个集合:训练集、交叉验证集、测试集,然后依次训练参数、超参数,最终得到最优的模型。

因此,上图可以进一步细化成下图:



0?wx_fmt=jpeg


或者下图:


0?wx_fmt=jpeg
可见机器学习过程是一个反复迭代不断优化的过程。其中很大一部分工作是在调整参数和超参数。


4、先跑跑再说:初步运行代码


Michael Nielsen的代码封装得很好,只需以下5行命令就可以生成神经网络并测试结果,并达到94.76%的正确率!


0?wx_fmt=png


第一个命令的功能是:将数据集拆分成三个集合:训练集、交叉验证集、测试集。


第二个命令的功能是:生成神经网络对象,神经网络结构为三层,每层节点数依次为(784, 30, 10)。


第三个命令的功能是:用(mini-batch)梯度下降法训练神经网络(权重与偏移),并生成测试结果。


该命令设定了三个超参数:训练回合数=30, 用于随机梯度下降法的最小样本数(mini-batch-size)=10,步长=3.0。


总共的输出结果如下:


0?wx_fmt=png


5、神经网络如何识别手写数字:启发式理解


首先,我们解释一下神经网络每层的功能。



0?wx_fmt=jpeg


第一层是输入层。因为mnist数据集中每一个手写数字样本是一个28*28像素的图像,因此对于每一个样本,其输入的信息就是每一个像素对应的灰度,总共有28*28=784个像素,故这一层有784个节点。


第三层是输出层。因为阿拉伯数字总共有10个,我们就要将样本分成10个类别,因此输出层我们采用10个节点。当样本属于某一类(某个数字)的时候,则该类(该数字)对应的节点为1,而剩下9个节点为0,如[0,0,0,1,0,0,0,0,0,0]。


因此,我们每一个样本(手写数字的图像)可以用一个超长的784维的向量表示其特征,而用一个10维向量表示该样本所属的类别(代表的真实数字),或者叫做标签。


mnist的数据就是这样表示的。所以,如果你想看训练集中第n个样本的784维特征向量,直接看training_data[n][0]就可以找到,而要看其所属的标签,看training_data[n][1]就够了。


那么,第二层神经网络所代表的意义怎么理解?这其实是很难的。但是我们可以有一个启发式地理解,比如用中间层的某一个节点表示图像中的某一个小区域的特定图像。这样,我们可以假设中间层的头4个节点依次用来识别图像左上、右上、左下、右下4个区域是否存在这样的特征的。


0?wx_fmt=png

如果这四个节点的值都很高,说明这四个区域同时满足这些特征。将以上的四个部分拼接起来,我们会发现,输入样本很可能就是一个手写“0”!


0?wx_fmt=jpeg 


因此,同一层的几个神经元同时被激活了意味着输入样本很可能是某个数字。


当然,这只是对神经网络作用机制的一个启发式理解。真实的过程却并不一定是这样。但通过启发式理解,我们可以对神经网络作用机制有一个更加直观的认识。


由此可见,神经网络能够识别手写数字的关键是它有能够对特定的图像激发特定的节点。而神经网络之所以能够针对性地激发这些节点,关键是它具有能够适应相关问题场景的权重和偏移。那这些权重和偏移如何训练呢?


6、神经网络如何训练:进一步阅读代码


上文已经图解的方式介绍了机器学习解决问题的一般思路,但是具体到神经网络将是如何训练呢?


其实最快的方式是直接阅读代码。我们将代码的结构用下图展示出来,运用其内置函数名表示基本过程,发现与我们上文分析的思路一模一样:


0?wx_fmt=jpeg


简单解释一下,在神经网络模型中:


  • 所需要求的关键参数就是:神经网络的权重(self.weights)和偏移(self.biases)。

  • 超参数是:隐藏层的节点数=30,训练回合数(epochs)=30, 用于随机梯度下降法的最小样本数(mini_batch_size)=10,步长(eta)=3.0。

  • 用随机梯度下降法调整参数: 0?wx_fmt=jpeg

  • 用反向传播法求出随机梯度下降法所需要的梯度(偏导数): backprop()

  • 用输出向量减去标签向量衡量训练误差:cost_derivative() = output_activations-y


全部代码如下(去掉注释之后,只有74行):

0?wx_fmt=png
0?wx_fmt=png
0?wx_fmt=png

0?wx_fmt=png
0?wx_fmt=png


7、神经网络如何优化:训练超参数与多种模型对比


由以上分析可知,神经网络只需要74行代码就可以完成编程,可见机器学习真正困难的地方并不在编程,而在你对数学过程本身,和对它与现实问题的对应关系有深入的理解。理解深入后,你才能写出这样的程序,并对其进行精微的调优。


我们初步的结果已经是94.76%的正确率了。但如果要将准确率提得更高怎么办?


这其实是一个开放的问题,有许多方法都可以尝试。我们这里仅仅是抛砖引玉。


首先,隐藏层只有30个节点。由我们之前对隐藏层的启发式理解可以猜测,神经网络的识别能力其实与隐藏层对一些细节的识别能力正相关。如果隐藏层的节点更多的话,其识别能力应该会更强的。那么我们设定100个隐藏层节点试试?


0?wx_fmt=png


发现结果如下:


0?wx_fmt=png


发现,我们只是改了一个超参数,准确率就从94.76%提升到96.72%!


这里强调一下,更加规范的模型调优方法是将多个模型用交叉验证集的结果来横向比较,选出最优模型后再用一个新的测试集来最终评估该模型。本文为了与之前的结果比较,才采用了测试集而不是交叉验证集。读者千万不要学博主这样做哈,因为这很有可能会过拟合。这是工程实践中数据挖掘人员经常犯的错误,我们之后会专门写篇博文探讨。


我们现在回来继续调优我们的模型。那么还有其他的隐藏节点数更合适吗?这个我们也不知道。常见的方法是用几何级数增长的数列(如:10,100,1000,……)去尝试,然后不断确定合适的区间,最终确定一个相对最优的值。

但是即便如此,我们也只尝试了一个超参数,还有其他的超参数没有调优呢。我们于是尝试另一个超参数:步长。之前的步长是3.0,但是我们可能觉得学习速率太慢了。那么尝试一个更大的步长试试?比如100?


0?wx_fmt=png

发现结果如下:


0?wx_fmt=png


发现准确率低得不忍直视,看来步长设得太长了。根本跑不到最低点。那么我们设定一个小的步长试试?比如0.01。


0?wx_fmt=png


结果如下:


0?wx_fmt=png


呃,发现准确率同样低得不忍直视。但是有一个优点,准确率是稳步提升的。说明模型在大方向上应该还是对的。如果在调试模型的时候忽视了这个细节,你可能真的找不到合适的参数。


可见,我们第一次尝试的神经网络结构的超参数设定还是比较不错的。但是真实的应用场景中,基本没有这样好的运气,很可能刚开始测试出来的结果全是奇葩生物,长得违反常理,就像来自另一个次元似的。这是数据挖掘工程师常见的情况。此时最应该做的,就是遏制住心中数万草泥马的咆哮奔腾,静静地观察测试结果的分布规律,尝试找到些原因,再继续将模型试着调优下去,与此同时,做好从一个坑跳入下一个坑的心理准备。当然,在机器学习工程师前赴后继的填坑过程中,还是总结出了一些调优规律。我们会在接下来专门写博文分析。


当然,以上的调优都没有逃出神经网络模型本身的范围。但是可不可能其他的模型效果更好?比如传说中的支持向量机?关于支持向量机的解读已经超越了本文的篇幅,我们也考虑专门撰写博文分析。但是在这里我们只是引用一下在scikit-learn中提供好的接口,底层是用性能更好的C语言封装的著名的LIBSVM。

相关代码也在Michael Nielsen的文件中。直接引入,并运行一个方法即可。


0?wx_fmt=png

我们看看结果:


0?wx_fmt=png


94.35%,好像比我们的神经网络低一点啊。看来我们的神经网络模型还是更优秀一些?


然而,实际情况并非如此。因为我们用的只是scikit-learn给支持向量机的设好的默认参数。支持向量机同样有一大堆可调的超参数,以提升模型的效果。 跟据 Andreas Mueller的这篇博文,调整好超参数的支持向量机能够达到98.5%的准确度!比我们刚才最好的神经网络提高了1.8个百分点!


然而,故事并没有结束。2013年,通过深度神经网络,研究者可以达到99.79%的准确度!而且,他们并没有运用很多高深的技术。很多技术在我们接下来的博文中都可以继续介绍。


所以,从目前的准确度来看:


简单的支持向量机<浅层神经网络<调优的支持向量机<深度神经网络


但还是要提醒一下,炫酷的算法固然重要,但是良好的数据集有时候比算法更重要。Michael Nielsen专门写了一个公式来来表达他们的关系:


精致的算法 ≤ 简单的算法 + 良好的训练数据
sophisticated algorithm ≤ simple learning algorithm + good training data.


所以为了调优模型,往往要溯源到数据本身,好的数据真的会有好的结果。


原文发布时间为:2015-12-29

本文来自云栖社区合作伙伴“大数据文摘”,了解相关信息可以关注“BigDataDigest”微信公众号

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

相关文章
阿里云服务器怎么设置密码?怎么停机?怎么重启服务器?
如果在创建实例时没有设置密码,或者密码丢失,您可以在控制台上重新设置实例的登录密码。本文仅描述如何在 ECS 管理控制台上修改实例登录密码。
9995 0
手把手教你生成对抗网络 GAN,50 行代码玩转 GAN 模型!
本文为大家介绍了生成对抗网络(Generate Adversarial Network,GAN),以最直白的语言来讲解它,最后实现一个简单的 GAN 程序来帮助大家加深理解。
1625 0
阿里云服务器如何登录?阿里云服务器的三种登录方法
购买阿里云ECS云服务器后如何登录?场景不同,阿里云优惠总结大概有三种登录方式: 登录到ECS云服务器控制台 在ECS云服务器控制台用户可以更改密码、更换系.
13792 0
从0开始:500行代码实现 LSM 数据库
LSM-Tree 是很多 NoSQL 数据库引擎的底层实现,例如 LevelDB,Hbase 等。本文基于《数据密集型应用系统设计》中对 LSM-Tree 数据库的设计思路,结合代码实现完整地阐述了一个迷你数据库,核心代码 500 行左右,通过理论结合实践来更好地理解数据库的原理。
10974 0
检查HTTP 的 Basic认证代码示例-JSP
检查HTTP 的 Basic认证. since http1.0 代码如下所示: R U OK? R U OK? . Your Password is 请参考代码中的注释,具体信息,还可以参考《图解HTTP》。
724 0
git统计某一开发者提交代码的增删改动行数和具体详细的改动内容
git统计某一开发者提交代码的增删改动行数和具体详细的改动内容 git命令: echo "统计结果" && git log --author="zhangphil" --after="2018-04-16 00:0...
2596 0
git统计某一名开发者有效代码总行数以及历史删除、增加的总行数
git统计某一名开发者有效代码总行数以及历史删除、增加的总行数 git命令: git log --author="zhangphil" --pretty=tformat: --numstat | gawk '{ add...
2640 0
Spring Boot项目利用MyBatis Generator进行数据层代码自动生成
概 述 MyBatis Generator (简称 MBG) 是一个用于 MyBatis和 iBATIS的代码生成器。它可以为 MyBatis的所有版本以及 2.2.0之后的 iBATIS版本自动生成 ORM层代码,典型地包括我们日常需要手写的 POJO、mapper xml 以及 mapper 接口等。
1840 0
基于Ganos百行代码实现亿级矢量空间数据在线可视化
本文介绍如何使用RDS PG或PolarDB(兼容PG版或Oracle版)的Ganos时空引擎提供的数据库快显技术,仅用百行代码实现亿级海量几何空间数据的在线快速显示和流畅地图交互,且无需关注切片存储和效率问题。
1812 0
2736
文章
6591
问答
来源圈子
更多
+ 订阅
文章排行榜
最热
最新
相关电子书
更多
《2021云上架构与运维峰会演讲合集》
立即下载
《零基础CSS入门教程》
立即下载
《零基础HTML入门教程》
立即下载