TensorFlow MNIST Handwritten Digit Recognition (Neural Network Best Practices Version)

Summary: TensorFlow MNIST handwritten digit recognition, written following the neural-network best practices (regularization, moving averages, learning-rate decay, checkpointing).

(1) Import libraries

import numpy as np
import matplotlib.pyplot as plt
import os 
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.examples.tutorials.mnist import input_data

(2) Load the data


mnist = input_data.read_data_sets("datasets/MNIST_data/", one_hot=True)
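
Note that the `tensorflow.examples.tutorials.mnist` module only exists in TensorFlow 1.x and has since been removed. If it is unavailable, a roughly equivalent dataset can be prepared with the Keras loader; the sketch below is my own substitution (not part of the original code) and flattens the images to 784 features and one-hot encodes the labels so the shapes match what the rest of the code expects (the Keras split is 60000/10000 with no separate validation set).

# Alternative loader, assuming tf.keras is available; illustration only.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0   # (60000, 784)
x_test  = x_test.reshape(-1, 784).astype(np.float32) / 255.0    # (10000, 784)
y_train = tf.keras.utils.to_categorical(y_train, 10)            # one-hot (60000, 10)
y_test  = tf.keras.utils.to_categorical(y_test, 10)             # one-hot (10000, 10)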


(3) Define hyperparameters

learning_rate = 0.0001          # unused here; the training loop builds its own exponentially decayed rate below
num_epochs = 10000              # number of mini-batch training steps
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8        # initial learning rate
LEARNING_RATE_DECAY = 0.99      # decay rate of the learning rate
MOVING_AVERAGE_DECAY = 0.99     # decay rate for the moving average of the weights
REGULARIZER_RATE = 0.0001       # L2 regularization strength
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"
(m,n_x) = mnist.train.images.shape   # m = 55000 examples, n_x = 784 input features
n_y = mnist.train.labels.shape[1]    # n_y = 10 output classes
n_1 = 500                            # hidden layer size
costs = []                           # loss snapshots for plotting
tf.set_random_seed(1)
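
The model-size constants are derived from the data itself; a quick shape check makes that explicit (the split sizes below are the defaults of read_data_sets):

print(mnist.train.images.shape)       # (55000, 784)  -> m = 55000, n_x = 784
print(mnist.train.labels.shape)       # (55000, 10)   -> n_y = 10 (one-hot labels)
print(mnist.validation.images.shape)  # (5000, 784)
print(mnist.test.images.shape)        # (10000, 784)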

(4) Initialize parameters

def init_para():
    W1 = tf.get_variable("w1",[n_x,n_1],initializer = tf.contrib.layers.xavier_initializer(seed = 1))   #(784,500)
    b1 = tf.get_variable("b1",[1,n_1],  initializer = tf.zeros_initializer())                          #(1,500)
    W2 = tf.get_variable("w2",[n_1,n_y],initializer = tf.contrib.layers.xavier_initializer(seed = 1))  #(500,10)
    b2 = tf.get_variable("b2",[1,n_y],  initializer = tf.zeros_initializer())                          #(1,10)
    return W1,b1,W2,b2
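
For reference, `tf.contrib.layers.xavier_initializer` defaults to uniform Glorot initialization, drawing from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). A plain-NumPy sketch of the same rule for W1 (illustration only, not used by the model):

limit = np.sqrt(6.0 / (n_x + n_1))                                   # fan_in=784, fan_out=500
w1_sample = np.random.uniform(-limit, limit, size=(n_x, n_1)).astype(np.float32)
print(limit, w1_sample.shape)                                        # ~0.0683 (784, 500)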

(5) Forward propagation

def forward(X, parameters, regularizer, variable_averages):
    W1,b1,W2,b2 = parameters
    # L2 regularization: add the weight penalties to the 'losses' collection
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(W1))
        tf.add_to_collection('losses', regularizer(W2))
    # Moving average: use the shadow (averaged) values of the weights for inference
    if variable_averages is not None:
        Z1 = tf.nn.relu(tf.matmul(X, variable_averages.average(W1)) + variable_averages.average(b1))  # (batch, 500)
        Z2 = tf.matmul(Z1, variable_averages.average(W2)) + variable_averages.average(b2)             # (batch, 10)
    else:
        Z1 = tf.nn.relu(tf.matmul(X, W1) + b1)  # (batch, 500)
        Z2 = tf.matmul(Z1, W2) + b2             # (batch, 10)
    return Z2
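
A minimal shape check of forward() on a dummy batch confirms that the network maps 784 input features to 10 logits. This standalone snippet is my own addition and builds a throwaway graph just for the check:

ops.reset_default_graph()                      # throwaway graph for the check
params = init_para()
X_check = tf.placeholder(tf.float32, shape=(None, n_x))
logits = forward(X_check, params, None, None)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(logits, feed_dict={X_check: np.zeros((4, n_x), dtype=np.float32)})
    print(out.shape)                           # (4, 10)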

(6) Model training

def train():
    X = tf.placeholder(tf.float32, shape=(None,n_x), name="X")  # (batch, 784)
    Y = tf.placeholder(tf.float32, shape=(None,n_y), name="Y")  # (batch, 10)
    parameters = init_para()
    global_step = tf.Variable(0, trainable = False)
    # L2 regularization
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZER_RATE)
    # Moving average over all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    Y_ = forward(X, parameters, regularizer, None)            # logits from the raw weights (used for training)
    Y_avg = forward(X, parameters, None, variable_averages)   # logits from the averaged weights (used for evaluation)
    # Cross-entropy loss plus the L2 penalties collected in 'losses'
    cem = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits = Y_, labels = Y))
    cost = cem + tf.add_n(tf.get_collection('losses'))
    # Exponentially decayed learning rate (see the sketch after this function)
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, m/BATCH_SIZE,
                    LEARNING_RATE_DECAY, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, global_step = global_step)
    #optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step = global_step)  # alternative: Adam, which does not rely on the decay schedule above
    with tf.control_dependencies([optimizer, variable_averages_op]):
        train_op = tf.no_op(name = 'train')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)        
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for i in range(num_epochs):
            x,y = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op,feed_dict={X:x,Y:y})
            if i%500 == 0:
                cost_v = sess.run(cost,feed_dict={X:x,Y:y})
                costs.append(cost_v)
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step = global_step)
                print(i,cost_v)
        # Calculate accuracy using the moving-average weights (Y_avg)
        correct_prediction = tf.equal(tf.argmax(Y_avg,1), tf.argmax(Y,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        print("Train Accuracy:", accuracy.eval({X: mnist.train.images, Y: mnist.train.labels}))
        print("Test Accuracy:", accuracy.eval({X: mnist.test.images, Y: mnist.test.labels}))
    plt.plot(np.squeeze(costs))
    plt.ylabel('cost')
    plt.xlabel('iterations (per 500)')
    plt.title("Base learning rate = " + str(LEARNING_RATE_BASE))
    plt.show()
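
With staircase=True, the learning rate produced by tf.train.exponential_decay stays at LEARNING_RATE_BASE for one pass over the training set (m/BATCH_SIZE = 550 steps) and is then multiplied by LEARNING_RATE_DECAY after each further pass. A plain-Python sketch of the same schedule (illustrative only):

def decayed_lr(step, base=LEARNING_RATE_BASE, decay=LEARNING_RATE_DECAY, decay_steps=m // BATCH_SIZE):
    # staircase=True means integer division, so the rate drops in discrete steps
    return base * decay ** (step // decay_steps)

print(decayed_lr(0))      # 0.8
print(decayed_lr(550))    # 0.8 * 0.99 = 0.792
print(decayed_lr(9999))   # after 18 decay steps, about 0.668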

(7) Model evaluation

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        X = tf.placeholder(tf.float32, shape=(None,n_x), name="X")  # (10000, 784) when fed the test set
        Y = tf.placeholder(tf.float32, shape=(None,n_y), name="Y")  # (10000, 10)
        test_feed = {X: mnist.test.images, Y: mnist.test.labels}
        parameters = init_para()
        Y_ = forward(X, parameters, None, None)
        correct_prediction = tf.equal(tf.argmax(Y_,1), tf.argmax(Y,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        # Restore the moving-average (shadow) values into the weight variables (see the sketch after this function)
        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
        variable_averages_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variable_averages_restore)
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess,ckpt.model_checkpoint_path)
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                accuracy_feed = sess.run(accuracy, feed_dict = test_feed)
                print("After %s training steps, valadation accuracy = %g" %(global_step,accuracy_feed))
            else:
                print("No checkpoint file found")


(8) Main program

if __name__ == '__main__':
    ops.reset_default_graph()   
    train()
    evaluate(mnist)

