# (1) 引入函数库 — import libraries
import numpy as np import matplotlib.pyplot as plt import os import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.examples.tutorials.mnist import input_data
# (2) 加载数据 — load the data
# Load MNIST with one-hot labels so they match the (None, n_y) label
# placeholders and the softmax cross-entropy used during training.
mnist = input_data.read_data_sets(train_dir="datasets/MNIST_data/", one_hot=True)
# (3) 定义参数 — define hyperparameters and constants
# --- Optimisation hyperparameters ---
learning_rate = 0.0001  # NOTE(review): not read anywhere visible; train() builds its own decayed learning_rate
num_epochs = 10000  # number of mini-batch training steps run by train() (not full passes)
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8  # initial learning rate passed to tf.train.exponential_decay
LEARNING_RATE_DECAY = 0.99  # decay factor passed to tf.train.exponential_decay
MOVING_AVERAGE_DECCAY = 0.99  # (sic) EMA decay; spelling kept because evaluate() reads this name
REGULARIZER_RATE = 0.0001  # scale of the L2 weight penalty

# --- Checkpoint location ---
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"

# --- Network dimensions, derived from the loaded training set ---
m, n_x = mnist.train.images.shape  # m training examples, n_x input features (784 for MNIST)
n_y = mnist.train.labels.shape[1]  # number of classes (10 for MNIST)
n_1 = 500  # hidden-layer width

costs = []  # training costs sampled by train(); module-level, so repeated train() calls accumulate
tf.set_random_seed(1)
# (4) 初始化参数 — initialize the weights and biases
def init_para():
    """Create the weight and bias variables of the two-layer network.

    Returns:
        Tuple ``(W1, b1, W2, b2)``:
            W1 -- ``[n_x, n_1]`` weights named "w1", Xavier-initialised (seed 1).
            b1 -- ``[1, n_1]`` zero-initialised biases named "b1".
            W2 -- ``[n_1, n_y]`` weights named "w2", Xavier-initialised (seed 1).
            b2 -- ``[1, n_y]`` zero-initialised biases named "b2".

    Variables are created with ``tf.get_variable`` without ``reuse``. Calling
    this twice in the same graph would presumably fail on the duplicate names.
    evaluate() builds a fresh graph before calling it.
    """
    xavier = tf.contrib.layers.xavier_initializer
    zeros = tf.zeros_initializer

    W1 = tf.get_variable("w1", shape=[n_x, n_1], initializer=xavier(seed=1))
    b1 = tf.get_variable("b1", shape=[1, n_1], initializer=zeros())
    W2 = tf.get_variable("w2", shape=[n_1, n_y], initializer=xavier(seed=1))
    b2 = tf.get_variable("b2", shape=[1, n_y], initializer=zeros())
    return (W1, b1, W2, b2)
# (5) 正向传播 — forward propagation
def forward(X, parameters, regularizer, variable_averages):
    """Build the logits of the two-layer fully connected network.

    Args:
        X: input tensor for the first matmul (presumably ``(batch, n_x)``
            flattened images; verify against callers).
        parameters: tuple ``(W1, b1, W2, b2)`` as returned by ``init_para()``.
        regularizer: optional callable, e.g. an L2 regularizer. When it is not
            None, ``regularizer(W1)`` and ``regularizer(W2)`` are appended to the
            'losses' graph collection as a side effect. Each call appends again,
            so a caller that sums the collection counts the penalty once per
            call made with a regularizer.
        variable_averages: optional ``tf.train.ExponentialMovingAverage``. When
            it is not None, the shadow (moving-average) copy of every parameter
            is used instead of the raw variable.

    Returns:
        Unnormalised logits ``relu(X @ W1 + b1) @ W2 + b2``. No softmax is applied.
    """
    W1, b1, W2, b2 = parameters

    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(W1))
        tf.add_to_collection('losses', regularizer(W2))

    if variable_averages is not None:
        # Swap in the EMA shadow variables; the layer math below is shared.
        W1, b1, W2, b2 = [variable_averages.average(v) for v in (W1, b1, W2, b2)]

    A1 = tf.nn.relu(tf.matmul(X, W1) + b1)  # hidden-layer activations
    return tf.matmul(A1, W2) + b2  # output logits
# (6) 模型训练 — model training
def train():
    """Train the two-layer MNIST classifier, checkpointing it periodically.

    Builds the training graph on the current default graph:
    - placeholders, with weights from ``init_para()``;
    - an L2 penalty on the weight matrices;
    - an exponential moving average (EMA) of all trainable variables;
    - plain gradient descent with an exponentially decaying learning rate.

    If a checkpoint exists in ``MODEL_SAVE_PATH``, training resumes from it.
    The function then runs ``num_epochs`` mini-batch steps of ``BATCH_SIZE``
    examples each.

    Every 500 steps it does three things:
    - appends the regularised cost on the current batch to the module-level
      ``costs`` list and prints it;
    - writes a checkpoint named ``MODEL_NAME-<global_step>``.

    Afterwards it prints the training-set accuracy of the EMA model and plots
    the recorded costs.
    """
    X = tf.placeholder(tf.float32, shape=(None, n_x), name="X")
    Y = tf.placeholder(tf.float32, shape=(None, n_y), name="Y")
    parameters = init_para()
    global_step = tf.Variable(0, trainable=False)

    # L2 penalty on W1/W2. forward() registers it in the 'losses' collection.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZER_RATE)

    # Shadow (EMA) copies of the trainable variables. This uses the
    # moving-average decay, not the learning-rate decay, so it matches the
    # EMA that evaluate() restores.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECCAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    # The raw weights drive training and register the regularisation terms.
    # The averaged weights are used only for the accuracy report.
    Y_ = forward(X, parameters, regularizer, None)
    Y_avg = forward(X, parameters, None, variable_averages)

    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=Y_, labels=Y))
    cost = cross_entropy + tf.add_n(tf.get_collection('losses'))

    # Decay the learning rate once per pass over the training set (m / BATCH_SIZE steps).
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, m / BATCH_SIZE,
        LEARNING_RATE_DECAY, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        cost, global_step=global_step)

    # One op that performs a gradient step and then refreshes the EMA shadows.
    with tf.control_dependencies([optimizer, variable_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    # tf.train.Saver.save refuses to write into a directory that does not exist.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for i in range(num_epochs):
            x, y = mnist.train.next_batch(BATCH_SIZE)
            feed = {X: x, Y: y}
            sess.run(train_op, feed_dict=feed)
            if i % 500 == 0:
                cost_v = sess.run(cost, feed_dict=feed)
                costs.append(cost_v)
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                           global_step=global_step)
                print(i, cost_v)

        # Accuracy of the EMA model on the full training set.
        # Test-set accuracy is reported by evaluate().
        correct_prediction = tf.equal(tf.argmax(Y_avg, 1), tf.argmax(Y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        print("Train Accuracy:", accuracy.eval({X: mnist.train.images, Y: mnist.train.labels}))

        # learning_rate is a Tensor, so evaluate it for the plot title.
        final_learning_rate = sess.run(learning_rate)

    plt.plot(np.squeeze(costs))
    plt.ylabel('cost')
    plt.xlabel('iterations (per 500)')  # one cost sample is taken every 500 steps
    plt.title("Final learning rate =" + str(final_learning_rate))
    plt.show()
# (7) 模型评估 — model evaluation
def evaluate(mnist):
    """Print the test-set accuracy of the newest checkpoint, using EMA weights.

    Builds a fresh graph and restores every variable from its
    exponential-moving-average shadow in the latest checkpoint under
    ``MODEL_SAVE_PATH``. It then prints the accuracy on ``mnist.test``. If
    there is no checkpoint, it prints a notice and returns.

    Args:
        mnist: dataset object exposing ``test.images`` and one-hot
            ``test.labels``. The main program passes the module-level object
            from ``input_data.read_data_sets(..., one_hot=True)``.
    """
    with tf.Graph().as_default():
        X = tf.placeholder(tf.float32, shape=(None, n_x), name="X")
        Y = tf.placeholder(tf.float32, shape=(None, n_y), name="Y")
        test_feed = {X: mnist.test.images, Y: mnist.test.labels}

        parameters = init_para()
        Y_ = forward(X, parameters, None, None)
        correct_prediction = tf.equal(tf.argmax(Y_, 1), tf.argmax(Y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

        # Map each EMA shadow-variable name in the checkpoint onto the plain
        # variable, so the restored weights are the moving averages.
        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECCAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if not (ckpt and ckpt.model_checkpoint_path):
                print("No checkpoint file found")
                return

            saver.restore(sess, ckpt.model_checkpoint_path)
            # train() saves checkpoints as "<MODEL_NAME>-<global_step>".
            # basename() copes with the OS-native path separator, not just '/'.
            global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
            accuracy_feed = sess.run(accuracy, feed_dict=test_feed)
            print("After %s training steps, test accuracy = %g" % (global_step, accuracy_feed))
# (8) 主程序 — main program
if __name__ == '__main__':
    # Clear the default graph before train() builds its graph on it.
    ops.reset_default_graph()
    train()
    # evaluate() builds its own graph and restores the checkpoint train() saved.
    evaluate(mnist)