# 【Python实战】——神经网络识别手写数字（二）

【Python实战】——神经网络识别手写数字（一）+https://developer.aliyun.com/article/1506500

### 2.3 神经网络模型定义

运行程序：

ANN = NeuralNetwork(num_of_in_nodes = image_pixels, #输入
num_of_out_nodes = 10, #输出节点数
num_of_hidden_nodes = 100,#隐藏节点
learning_rate = 0.1)#学习率

### 2.4 模型训练

#### 2.4.1 预测概率

运行程序：

for i in range(len(train_imgs)):
ANN.train(train_imgs[i], train_labels_one_hot[i])
for i in range(20):
res = ANN.run(test_imgs[i])
print(test_labels[i], np.argmax(res), np.max(res))

运行结果：

[7.] 7 0.9992648448921
[2.] 2 0.9040034245332168
[1.] 1 0.9992201001324703
[0.] 0 0.9923701545281887
[4.] 4 0.989297708155559
[1.] 1 0.9984582148795715
[4.] 4 0.9957673752296046
[9.] 9 0.9889417895800644
[5.] 6 0.5009071817613537
[9.] 9 0.9879513019542627
[0.] 0 0.9932950902790246
[6.] 6 0.9387061553685657
[9.] 9 0.9962530965286298
[0.] 0 0.9974524110371016
[1.] 1 0.9991354417269441
[5.] 5 0.7607733657668813
[9.] 9 0.9968080255475414
[7.] 7 0.9967748204232602
[3.] 3 0.8820920415159276
[4.] 4 0.9978584850755227

#### 2.4.2 训练集正确率

运行程序：

corrects, wrongs = ANN.evaluate(train_imgs, train_labels)#训练集判别正确和错误数量
print("accuracy train: ", corrects / ( corrects + wrongs))##正确率

运行结果：

accuracy train:  0.9425333333333333

#### 2.4.3 测试集正确率

运行程序：

corrects, wrongs = ANN.evaluate(test_imgs, test_labels)
print("accuracy: test", corrects / ( corrects + wrongs))#测试集正确率

运行结果：

accuracy: test 0.9412

#### 2.4.4 训练集判别矩阵

运行程序：

cm = ANN.confusion_matrix(train_imgs, train_labels)
print(cm)   #训练集判别矩阵

运行结果：

[[5822    1   54   35   15   41   47   12   31   31]
[   2 6638   62   31   17   24   21   64  163   14]
[   6   19 5487   57   16    9    2   45   16    4]
[   7   27   87 5773    3  130    3   16  148   67]
[  11   11   68    8 5332   34   12   48   28   44]
[  10    4    6   69    0 4952   34    5   32    5]
[  31    5   53   19   49   96 5782    5   37    2]
[   1    9   45   35    6    6    0 5812    5   28]
[  20    9   70   32    9   37   15   11 5209    9]
[  13   19   26   72  395   92    2  247  182 5745]]

#### 2.4.5 不同数字预测精确率

运行程序：

for i in range(10):
print("digit: ", i, "precision: ", ANN.precision(i, cm))

运行结果：

digit:  0 precision:  0.9829478304913051
digit:  1 precision:  0.9845743102936814
digit:  2 precision:  0.9209466263846928
digit:  3 precision:  0.9416082205186755
digit:  4 precision:  0.9127011297500855
digit:  5 precision:  0.9134845969378343
digit:  6 precision:  0.9770192632646164
digit:  7 precision:  0.9276935355147645
digit:  8 precision:  0.8902751666381815
digit:  9 precision:  0.9657085224407463

### 2.5 结果可视化

#### 2.5.1 每次epoch训练预测情况

运行程序：

epochs = 30
train_acc=[]
test_acc=[]
NN = NeuralNetwork(num_of_in_nodes = image_pixels,
num_of_out_nodes = 10,
num_of_hidden_nodes = 100,
learning_rate = 0.1)
for epoch in range(epochs):
print("epoch: ", epoch)
for i in range(len(train_imgs)):
NN.train(train_imgs[i],
train_labels_one_hot[i])

corrects, wrongs = NN.evaluate(train_imgs, train_labels)
print("accuracy train: ", corrects / ( corrects + wrongs))
train_acc.append(corrects / ( corrects + wrongs))
corrects, wrongs = NN.evaluate(test_imgs, test_labels)
print("accuracy: test", corrects / ( corrects + wrongs))
test_acc.append(corrects / ( corrects + wrongs))

epoch:  0
accuracy train:  0.94455
accuracy: test 0.9422
epoch:  1
accuracy train:  0.9628
accuracy: test 0.9579
epoch:  2
accuracy train:  0.9699
accuracy: test 0.9637
epoch:  3
accuracy train:  0.9761166666666666
accuracy: test 0.9649
epoch:  4
accuracy train:  0.979
accuracy: test 0.9662
epoch:  5
accuracy train:  0.9820833333333333
accuracy: test 0.9679
epoch:  6
accuracy train:  0.9838166666666667
accuracy: test 0.9697
epoch:  7
accuracy train:  0.9845666666666667
accuracy: test 0.97
epoch:  8
accuracy train:  0.9855333333333334
accuracy: test 0.9703
epoch:  9
accuracy train:  0.9868166666666667
accuracy: test 0.97
epoch:  10
accuracy train:  0.9878166666666667
accuracy: test 0.9714
epoch:  11
accuracy train:  0.98845
accuracy: test 0.9716
epoch:  12
accuracy train:  0.98905
accuracy: test 0.9721
epoch:  13
accuracy train:  0.9898166666666667
accuracy: test 0.9723
epoch:  14
accuracy train:  0.9903
accuracy: test 0.9722
epoch:  15
accuracy train:  0.9907666666666667
accuracy: test 0.9719
epoch:  16
accuracy train:  0.9910833333333333
accuracy: test 0.9715
epoch:  17
accuracy train:  0.9918
accuracy: test 0.9714
epoch:  18
accuracy train:  0.9924166666666666
accuracy: test 0.971
epoch:  19
accuracy train:  0.99265
accuracy: test 0.9712
epoch:  20
accuracy train:  0.9932833333333333
accuracy: test 0.972
epoch:  21
accuracy train:  0.9939333333333333
accuracy: test 0.9716
epoch:  22
accuracy train:  0.9944333333333333
accuracy: test 0.972
epoch:  23
accuracy train:  0.9948
accuracy: test 0.9719
epoch:  24
accuracy train:  0.9950833333333333
accuracy: test 0.9718
epoch:  25
accuracy train:  0.9950833333333333
accuracy: test 0.9722
epoch:  26
accuracy train:  0.99525
accuracy: test 0.9725
epoch:  27
accuracy train:  0.9955833333333334
accuracy: test 0.972
epoch:  28
accuracy train:  0.9958166666666667
accuracy: test 0.9717
epoch:  29
accuracy train:  0.9962666666666666
accuracy: test 0.9717

#### 2.5.2 迭代30次正确率绘图

运行程序：

#正确率绘图
# matplotlib其实是不支持显示中文的 显示中文需要一行代码设置字体
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['font.family'] = 'SimHei'
plt.rcParams['axes.unicode_minus'] = False
import matplotlib.pyplot as plt
x=np.arange(1,31,1)
plt.title('迭代30次正确率')
plt.plot(x, train_acc, color='green', label='训练集')
plt.plot(x, test_acc, color='red', label='测试集')
plt.legend() # 显示图例
plt.show()

运行结果：

【Python实战】——神经网络识别手写数字（三）+https://developer.aliyun.com/article/1506502

|
23小时前
|
XML 数据库 数据格式
Python网络数据抓取（9）：XPath
Python网络数据抓取（9）：XPath
11 0
|
2天前
|

Python网络爬虫实战：抓取并分析网页数据

53 9
|
2天前
|

25 11
|
3天前
|

Python在金融数据分析中扮演关键角色，用于预测市场趋势和风险管理。本文通过案例展示了使用Python库（如pandas、numpy、matplotlib等）进行数据获取、清洗、分析和建立预测模型，例如计算苹果公司（AAPL）股票的简单移动平均线，以展示基本流程。此示例为更复杂的金融建模奠定了基础。【6月更文挑战第13天】
17 3
|
3天前
|

Python中的并发编程（4）多线程发送网络请求
Python中的并发编程（4）多线程发送网络请求
10 1
|
4天前
|

Python3网络开发实战读后感
Python3网络开发实战读后感
7 2
|
4天前
|

12 0
|
4天前
|

31 0
|
4天前
|

13 2
|
4天前
|

【功能超全】基于OpenCV车牌识别停车场管理系统软件开发【含python源码+PyqtUI界面+功能详解】-车牌识别python 深度学习实战项目
【功能超全】基于OpenCV车牌识别停车场管理系统软件开发【含python源码+PyqtUI界面+功能详解】-车牌识别python 深度学习实战项目
13 0