下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)
每一张图片包含28*28个像素,我们把这一个数组展开成一个向量,长度是28*28=784。因此在MNIST训练数据集中mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。图片里的某个像素的强度值介于0-1之间。
MNIST数据集的标签是介于0-9的数字,我们要把标签转化为“one-hot vectors”。一个one-hot向量除了某一位数字是1以外,其余维度数字都是0,比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]),标签3将表示为([0,0,0,1,0,0,0,0,0,0]) 。
因此, mnist.train.labels 是一个 [60000, 10] 的数字矩阵。
我们知道MNIST的结果是0-9,我们的模型可能推测出一张图片是数字9的概率是80%,是数字8的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模型的经典案例。softmax模型可以用来给不同的对象分配概率。
import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_data #读取mnist数据集 如果没有则会下载 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) #每个批次的大小 batch_size = 50 #计算一共有多少批次 n_batch = mnist.train.num_examples // batch_size #定义两个占位符 x = tf.placeholder(tf.float32,[None,784]) y = tf.placeholder(tf.float32,[None,10]) #创建简单的神经网络 #群值 W = tf.Variable(tf.zeros([784,10])) #偏置值 b = tf.Variable(tf.zeros([10])) #预测值 prediction = tf.nn.softmax(tf.matmul(x,W)+b) #二次代价函数 loss = tf.reduce_mean(tf.square(y-prediction)) #使用梯度下降法 train_step = tf.train.GradientDescentOptimizer(0.3).minimize(loss) #初始化变量 init = tf.global_variables_initializer() #预测数据与样本比较,如果相等就返回1 求出标签 #结果存放在布尔型列表中 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置 #求准确率 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #进行训练 with tf.Session() as sess: sess.run(init) for i in range(101):#周期 for batch in range(n_batch):#批次 batch_xs, batch_ys = mnist.train.next_batch(batch_size) sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print("周期 :"+ str(i) + "准确率:" + str(acc))
训练结果:
周期 :0准确率:0.8809 周期 :1准确率:0.8958 周期 :2准确率:0.9035 周期 :3准确率:0.9064 周期 :4准确率:0.9109 周期 :5准确率:0.9122 周期 :6准确率:0.9142 周期 :7准确率:0.9162 周期 :8准确率:0.917 周期 :9准确率:0.9178 周期 :10准确率:0.9188 周期 :11准确率:0.9184 周期 :12准确率:0.9192 周期 :13准确率:0.9197 周期 :14准确率:0.9212 周期 :15准确率:0.9202 周期 :16准确率:0.9218 周期 :17准确率:0.9218 周期 :18准确率:0.922 周期 :19准确率:0.9226 周期 :20准确率:0.9224 周期 :21准确率:0.9232 周期 :22准确率:0.9242 周期 :23准确率:0.924 周期 :24准确率:0.9238 周期 :25准确率:0.9246 周期 :26准确率:0.9247 周期 :27准确率:0.9248 周期 :28准确率:0.9253 周期 :29准确率:0.9251 周期 :30准确率:0.9259 周期 :31准确率:0.9261 周期 :32准确率:0.9259 周期 :33准确率:0.9263 周期 :34准确率:0.9275 周期 :35准确率:0.9268 周期 :36准确率:0.9271 周期 :37准确率:0.9269 周期 :38准确率:0.9279 周期 :39准确率:0.9279 周期 :40准确率:0.9271