Python-Tensorflow-MNIST手写识别

简介: Python-Tensorflow-MNIST手写识别

下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)

image.png

每一张图片包含28*28个像素,我们把这一个数组展开成一个向量,长度是28*28=784。因此在MNIST训练数据集中mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。图片里的某个像素的强度值介于0-1之间。

image.png

image.png

MNIST数据集的标签是介于0-9的数字,我们要把标签转化为“one-hot vectors”。一个one-hot向量除了某一位数字是1以外,其余维度数字都是0,比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]),标签3将表示为([0,0,0,1,0,0,0,0,0,0]) 。

  因此, mnist.train.labels 是一个 [60000, 10] 的数字矩阵。

image.png

image.png

我们知道MNIST的结果是0-9,我们的模型可能推测出一张图片是数字9的概率是80%,是数字8的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模型的经典案例。softmax模型可以用来给不同的对象分配概率。

image.png

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
#读取mnist数据集 如果没有则会下载
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
#每个批次的大小
batch_size = 50
#计算一共有多少批次
n_batch = mnist.train.num_examples // batch_size
#定义两个占位符
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
#创建简单的神经网络
#群值
W = tf.Variable(tf.zeros([784,10]))
#偏置值
b = tf.Variable(tf.zeros([10]))
#预测值
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.3).minimize(loss)
#初始化变量
init = tf.global_variables_initializer()
#预测数据与样本比较,如果相等就返回1 求出标签
#结果存放在布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#进行训练
with tf.Session() as sess:
    sess.run(init)
    for i in range(101):#周期
        for batch in range(n_batch):#批次
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("周期 :"+ str(i) + "准确率:" +  str(acc))

训练结果:

周期 :0准确率:0.8809
周期 :1准确率:0.8958
周期 :2准确率:0.9035
周期 :3准确率:0.9064
周期 :4准确率:0.9109
周期 :5准确率:0.9122
周期 :6准确率:0.9142
周期 :7准确率:0.9162
周期 :8准确率:0.917
周期 :9准确率:0.9178
周期 :10准确率:0.9188
周期 :11准确率:0.9184
周期 :12准确率:0.9192
周期 :13准确率:0.9197
周期 :14准确率:0.9212
周期 :15准确率:0.9202
周期 :16准确率:0.9218
周期 :17准确率:0.9218
周期 :18准确率:0.922
周期 :19准确率:0.9226
周期 :20准确率:0.9224
周期 :21准确率:0.9232
周期 :22准确率:0.9242
周期 :23准确率:0.924
周期 :24准确率:0.9238
周期 :25准确率:0.9246
周期 :26准确率:0.9247
周期 :27准确率:0.9248
周期 :28准确率:0.9253
周期 :29准确率:0.9251
周期 :30准确率:0.9259
周期 :31准确率:0.9261
周期 :32准确率:0.9259
周期 :33准确率:0.9263
周期 :34准确率:0.9275
周期 :35准确率:0.9268
周期 :36准确率:0.9271
周期 :37准确率:0.9269
周期 :38准确率:0.9279
周期 :39准确率:0.9279
周期 :40准确率:0.9271
目录
相关文章
|
4天前
|
机器学习/深度学习 算法 算法框架/工具
Python 迁移学习实用指南:1~5(2)
Python 迁移学习实用指南:1~5(2)
125 0
|
4天前
|
机器学习/深度学习 人工智能 算法
Python 迁移学习实用指南:1~5(1)
Python 迁移学习实用指南:1~5(1)
119 0
|
4天前
|
机器学习/深度学习 存储 算法
Python 无监督学习实用指南:1~5(3)
Python 无监督学习实用指南:1~5(3)
37 0
Python 无监督学习实用指南:1~5(3)
|
4天前
|
机器学习/深度学习 算法 关系型数据库
Python 无监督学习实用指南:6~10(4)
Python 无监督学习实用指南:6~10(4)
39 0
|
4天前
|
机器学习/深度学习 存储 算法
Python 无监督学习实用指南:1~5(5)
Python 无监督学习实用指南:1~5(5)
42 0
|
4天前
|
机器学习/深度学习 数据可视化 算法框架/工具
Python 迁移学习实用指南:1~5(5)
Python 迁移学习实用指南:1~5(5)
34 0
|
4天前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow与PyTorch在Python面试中的对比与应用
【4月更文挑战第16天】这篇博客探讨了Python面试中TensorFlow和PyTorch的常见问题,包括框架基础操作、自动求梯度与反向传播、数据加载与预处理。易错点包括混淆框架API、动态图与静态图的理解、GPU加速的利用、模型保存恢复以及版本兼容性。通过掌握这些问题和解决策略,面试者能展示其深度学习框架技能。
37 9
|
4天前
|
机器学习/深度学习 数据采集 算法
在Python中,特征提取
在Python中,特征提取
57 4
|
4天前
|
机器学习/深度学习 存储 算法框架/工具
Python 迁移学习实用指南:6~11(2)
Python 迁移学习实用指南:6~11(2)
48 0
|
4天前
|
机器学习/深度学习 数据可视化 算法
Python 迁移学习实用指南:6~11(3)
Python 迁移学习实用指南:6~11(3)
53 0