【Python实战】——神经网络识别手写数字(二)

简介: 【Python实战】——神经网络识别手写数字(三)

【Python实战】——神经网络识别手写数字(一)+https://developer.aliyun.com/article/1506500

2.3 神经网络模型定义

  运行程序:

ANN = NeuralNetwork(num_of_in_nodes = image_pixels, #输入
                    num_of_out_nodes = 10, #输出节点数
                    num_of_hidden_nodes = 100,#隐藏节点
                    learning_rate = 0.1)#学习率

2.4 模型训练

2.4.1 预测概率

  运行程序:

for i in range(len(train_imgs)):
    ANN.train(train_imgs[i], train_labels_one_hot[i])
for i in range(20):
    res = ANN.run(test_imgs[i])
    print(test_labels[i], np.argmax(res), np.max(res))

  运行结果:

[7.] 7 0.9992648448921
[2.] 2 0.9040034245332168
[1.] 1 0.9992201001324703
[0.] 0 0.9923701545281887
[4.] 4 0.989297708155559
[1.] 1 0.9984582148795715
[4.] 4 0.9957673752296046
[9.] 9 0.9889417895800644
[5.] 6 0.5009071817613537
[9.] 9 0.9879513019542627
[0.] 0 0.9932950902790246
[6.] 6 0.9387061553685657
[9.] 9 0.9962530965286298
[0.] 0 0.9974524110371016
[1.] 1 0.9991354417269441
[5.] 5 0.7607733657668813
[9.] 9 0.9968080255475414
[7.] 7 0.9967748204232602
[3.] 3 0.8820920415159276
[4.] 4 0.9978584850755227

2.4.2 训练集正确率

  运行程序:

corrects, wrongs = ANN.evaluate(train_imgs, train_labels)#训练集判别正确和错误数量
print("accuracy train: ", corrects / ( corrects + wrongs))##正确率

  运行结果:

accuracy train:  0.9425333333333333

2.4.3 测试集正确率

  运行程序:

corrects, wrongs = ANN.evaluate(test_imgs, test_labels)
print("accuracy: test", corrects / ( corrects + wrongs))#测试集正确率

  运行结果:

accuracy: test 0.9412

2.4.4 训练集判别矩阵

  运行程序:

cm = ANN.confusion_matrix(train_imgs, train_labels)
print(cm)   #训练集判别矩阵

  运行结果:

[[5822    1   54   35   15   41   47   12   31   31]
 [   2 6638   62   31   17   24   21   64  163   14]
 [   6   19 5487   57   16    9    2   45   16    4]
 [   7   27   87 5773    3  130    3   16  148   67]
 [  11   11   68    8 5332   34   12   48   28   44]
 [  10    4    6   69    0 4952   34    5   32    5]
 [  31    5   53   19   49   96 5782    5   37    2]
 [   1    9   45   35    6    6    0 5812    5   28]
 [  20    9   70   32    9   37   15   11 5209    9]
 [  13   19   26   72  395   92    2  247  182 5745]]

2.4.5 不同数字预测精确率

  运行程序:

for i in range(10):
    print("digit: ", i, "precision: ", ANN.precision(i, cm))

  运行结果:

digit:  0 precision:  0.9829478304913051
digit:  1 precision:  0.9845743102936814
digit:  2 precision:  0.9209466263846928
digit:  3 precision:  0.9416082205186755
digit:  4 precision:  0.9127011297500855
digit:  5 precision:  0.9134845969378343
digit:  6 precision:  0.9770192632646164
digit:  7 precision:  0.9276935355147645
digit:  8 precision:  0.8902751666381815
digit:  9 precision:  0.9657085224407463

2.5 结果可视化

2.5.1 每次epoch训练预测情况

  运行程序:

epochs = 30
train_acc=[]
test_acc=[]
NN = NeuralNetwork(num_of_in_nodes = image_pixels, 
                   num_of_out_nodes = 10, 
                   num_of_hidden_nodes = 100,
                   learning_rate = 0.1)
for epoch in range(epochs):  
    print("epoch: ", epoch)
    for i in range(len(train_imgs)):
        NN.train(train_imgs[i], 
                 train_labels_one_hot[i])
  
    corrects, wrongs = NN.evaluate(train_imgs, train_labels)
    print("accuracy train: ", corrects / ( corrects + wrongs))
    train_acc.append(corrects / ( corrects + wrongs))
    corrects, wrongs = NN.evaluate(test_imgs, test_labels)
    print("accuracy: test", corrects / ( corrects + wrongs))
    test_acc.append(corrects / ( corrects + wrongs))

运行结果:

epoch:  0
accuracy train:  0.94455
accuracy: test 0.9422
epoch:  1
accuracy train:  0.9628
accuracy: test 0.9579
epoch:  2
accuracy train:  0.9699
accuracy: test 0.9637
epoch:  3
accuracy train:  0.9761166666666666
accuracy: test 0.9649
epoch:  4
accuracy train:  0.979
accuracy: test 0.9662
epoch:  5
accuracy train:  0.9820833333333333
accuracy: test 0.9679
epoch:  6
accuracy train:  0.9838166666666667
accuracy: test 0.9697
epoch:  7
accuracy train:  0.9845666666666667
accuracy: test 0.97
epoch:  8
accuracy train:  0.9855333333333334
accuracy: test 0.9703
epoch:  9
accuracy train:  0.9868166666666667
accuracy: test 0.97
epoch:  10
accuracy train:  0.9878166666666667
accuracy: test 0.9714
epoch:  11
accuracy train:  0.98845
accuracy: test 0.9716
epoch:  12
accuracy train:  0.98905
accuracy: test 0.9721
epoch:  13
accuracy train:  0.9898166666666667
accuracy: test 0.9723
epoch:  14
accuracy train:  0.9903
accuracy: test 0.9722
epoch:  15
accuracy train:  0.9907666666666667
accuracy: test 0.9719
epoch:  16
accuracy train:  0.9910833333333333
accuracy: test 0.9715
epoch:  17
accuracy train:  0.9918
accuracy: test 0.9714
epoch:  18
accuracy train:  0.9924166666666666
accuracy: test 0.971
epoch:  19
accuracy train:  0.99265
accuracy: test 0.9712
epoch:  20
accuracy train:  0.9932833333333333
accuracy: test 0.972
epoch:  21
accuracy train:  0.9939333333333333
accuracy: test 0.9716
epoch:  22
accuracy train:  0.9944333333333333
accuracy: test 0.972
epoch:  23
accuracy train:  0.9948
accuracy: test 0.9719
epoch:  24
accuracy train:  0.9950833333333333
accuracy: test 0.9718
epoch:  25
accuracy train:  0.9950833333333333
accuracy: test 0.9722
epoch:  26
accuracy train:  0.99525
accuracy: test 0.9725
epoch:  27
accuracy train:  0.9955833333333334
accuracy: test 0.972
epoch:  28
accuracy train:  0.9958166666666667
accuracy: test 0.9717
epoch:  29
accuracy train:  0.9962666666666666
accuracy: test 0.9717

2.5.2 迭代30次正确率绘图

  运行程序:

#正确率绘图
# matplotlib其实是不支持显示中文的 显示中文需要一行代码设置字体  
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['font.family'] = 'SimHei'  
plt.rcParams['axes.unicode_minus'] = False   
import matplotlib.pyplot as plt 
x=np.arange(1,31,1)
plt.title('迭代30次正确率')
plt.plot(x, train_acc, color='green', label='训练集')
plt.plot(x, test_acc, color='red', label='测试集')
plt.legend() # 显示图例
plt.show()

  运行结果:

【Python实战】——神经网络识别手写数字(三)+https://developer.aliyun.com/article/1506502

相关文章
|
3天前
|
机器学习/深度学习 人工智能 算法
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
文本分类识别系统。本系统使用Python作为主要开发语言,首先收集了10种中文文本数据集("体育类", "财经类", "房产类", "家居类", "教育类", "科技类", "时尚类", "时政类", "游戏类", "娱乐类"),然后基于TensorFlow搭建CNN卷积神经网络算法模型。通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型,并保存为本地的h5格式。然后使用Django开发Web网页端操作界面,实现用户上传一段文本识别其所属的类别。
15 1
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
|
2天前
|
数据采集 存储 JavaScript
构建您的第一个Python网络爬虫:抓取、解析与存储数据
【9月更文挑战第24天】在数字时代,数据是新的金矿。本文将引导您使用Python编写一个简单的网络爬虫,从互联网上自动抓取信息。我们将介绍如何使用requests库获取网页内容,BeautifulSoup进行HTML解析,以及如何将数据存储到文件或数据库中。无论您是数据分析师、研究人员还是对编程感兴趣的新手,这篇文章都将为您提供一个实用的入门指南。拿起键盘,让我们开始挖掘互联网的宝藏吧!
|
1天前
|
数据采集 人工智能 数据挖掘
Python编程入门:从基础到实战的快速指南
【9月更文挑战第25天】本文旨在为初学者提供一个简明扼要的Python编程入门指南。通过介绍Python的基本概念、语法规则以及实际案例分析,帮助读者迅速掌握Python编程的核心技能。文章将避免使用复杂的专业术语,而是采用通俗易懂的语言和直观的例子来阐述概念,确保内容的可读性和实用性。
|
2天前
|
机器学习/深度学习 存储 数据挖掘
Python编程入门:从基础到实战
【9月更文挑战第24天】本教程将引领初学者步入Python编程的奇妙世界。我们将从最基础的概念开始,逐步深入,通过实例和练习,让你掌握这门强大而易学的语言。无论你是编程新手,还是希望扩展技能的开发者,这篇文章都将为你开启一段充满乐趣的编程之旅。
13 2
|
3天前
|
存储 前端开发 API
告别繁琐,拥抱简洁!Python RESTful API 设计实战,让 API 调用如丝般顺滑!
在 Web 开发的旅程中,设计一个高效、简洁且易于使用的 RESTful API 是至关重要的。今天,我想和大家分享一次我在 Python 中进行 RESTful API 设计的实战经历,希望能给大家带来一些启发。
13 3
|
4天前
|
算法 搜索推荐 开发者
别再让复杂度拖你后腿!Python 算法设计与分析实战,教你如何精准评估与优化!
在 Python 编程中,算法的性能至关重要。本文将带您深入了解算法复杂度的概念,包括时间复杂度和空间复杂度。通过具体的例子,如冒泡排序算法 (`O(n^2)` 时间复杂度,`O(1)` 空间复杂度),我们将展示如何评估算法的性能。同时,我们还会介绍如何优化算法,例如使用 Python 的内置函数 `max` 来提高查找最大值的效率,或利用哈希表将查找时间从 `O(n)` 降至 `O(1)`。此外,还将介绍使用 `timeit` 模块等工具来评估算法性能的方法。通过不断实践,您将能更高效地优化 Python 程序。
17 4
|
2天前
|
缓存 中间件 网络架构
Python Web开发实战:高效利用路由与中间件提升应用性能
在Python Web开发中,路由和中间件是构建高效、可扩展应用的核心组件。路由通过装饰器如`@app.route()`将HTTP请求映射到处理函数;中间件则在请求处理流程中插入自定义逻辑,如日志记录和验证。合理设计路由和中间件能显著提升应用性能和可维护性。本文以Flask为例,详细介绍如何优化路由、避免冲突、使用蓝图管理大型应用,并通过中间件实现缓存、请求验证及异常处理等功能,帮助你构建快速且健壮的Web应用。
8 1
|
2天前
|
数据可视化 数据挖掘 Linux
震撼发布!Python数据分析师必学,Matplotlib与Seaborn数据可视化实战全攻略!
在数据科学领域,数据可视化是连接数据与洞察的桥梁,能让复杂的关系变得直观。本文通过实战案例,介绍Python数据分析师必备的Matplotlib与Seaborn两大可视化工具。首先,通过Matplotlib绘制基本折线图;接着,使用Seaborn绘制统计分布图;最后,结合两者在同一图表中展示数据分布与趋势,帮助你提升数据可视化技能,更好地讲述数据故事。
11 1
|
3天前
|
机器学习/深度学习 人工智能 算法
【果蔬识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
【果蔬识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台。果蔬识别系统,本系统使用Python作为主要开发语言,通过收集了12种常见的水果和蔬菜('土豆', '圣女果', '大白菜', '大葱', '梨', '胡萝卜', '芒果', '苹果', '西红柿', '韭菜', '香蕉', '黄瓜'),然后基于TensorFlow库搭建CNN卷积神经网络算法模型,然后对数据集进行训练,最后得到一个识别精度较高的算法模型,然后将其保存为h5格式的本地文件方便后期调用。再使用Django框架搭建Web网页平台操作界面,实现用户上传一张果蔬图片识别其名称。
14 0
【果蔬识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
|
2天前
|
调度 Python
python3 协程实战(python3经典编程案例)
该文章通过多个实战案例介绍了如何在Python3中使用协程来提高I/O密集型应用的性能,利用asyncio库以及async/await语法来编写高效的异步代码。
9 0