TensorFlow MNIST手写数字识别(神经网络最佳实践版)

简介: TensorFlow MNIST手写数字识别(神经网络最佳实践版)

(1) 引入函数库

import numpy as np
import matplotlib.pyplot as plt
import os 
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.examples.tutorials.mnist import input_data

(2)加载数据


mnist = input_data.read_data_sets("datasets/MNIST_data/", one_hot=True)


(3)定义参数

learning_rate = 0.0001
num_epochs = 10000
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
MOVING_AVERAGE_DECCAY = 0.99
REGULARIZER_RATE = 0.0001
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"
(m,n_x) = mnist.train.images.shape #784
n_y = mnist.train.labels.shape[1] #10
n_1 = 500
costs = []
tf.set_random_seed(1)

(4)初始化参数

def init_para():
    W1 = tf.get_variable("w1",[n_x,n_1],initializer = tf.contrib.layers.xavier_initializer(seed = 1))   #(784,500)
    b1 = tf.get_variable("b1",[1,n_1],  initializer = tf.zeros_initializer())                          #(1,500)
    W2 = tf.get_variable("w2",[n_1,n_y],initializer = tf.contrib.layers.xavier_initializer(seed = 1))  #(500,10)
    b2 = tf.get_variable("b2",[1,n_y],  initializer = tf.zeros_initializer())                          #(1,10)
    return W1,b1,W2,b2

(5)正向传播

def forward(X, parameters, regularizer, variable_averages):
    W1,b1,W2,b2 = parameters
    # 正则化
    if regularizer != None:
        tf.add_to_collection('losses',regularizer(W1))
        tf.add_to_collection('losses',regularizer(W2))
    #滑动平均
    if variable_averages != None:
        Z1 = tf.nn.relu(tf.matmul(X,variable_averages.average(W1)) + variable_averages.average(b1)) #(55000,500)
        Z2 = tf.matmul(Z1,variable_averages.average(W2)) + variable_averages.average(b2)            #(55000,10)
    else:
        Z1 = tf.nn.relu(tf.matmul(X,W1) + b1) #(55000,500)
        Z2 = tf.matmul(Z1,W2) + b2            #(55000,10)
    return Z2

(6)模型训练

def train():
    X = tf.placeholder(tf.float32, shape=(None,n_x), name="X")  #(55000,784)
    Y = tf.placeholder(tf.float32, shape=(None,n_y), name="Y")  #(55000,10)
    prameters = init_para()
    global_step = tf.Variable(0, trainable = False)
    #正则化
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZER_RATE) 
    #滑动平均
    variable_averages = tf.train.ExponentialMovingAverage(LEARNING_RATE_DECAY,global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    Y_ = forward(X, prameters, regularizer, None)
    Y_avg = forward(X, prameters, None, variable_averages)
    #交叉熵损失函数
    cem = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits = Y_, labels = Y))
    cost = cem + tf.add_n(tf.get_collection('losses'))
    #指数衰减
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,m/BATCH_SIZE,
                    LEARNING_RATE_DECAY,staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost,global_step = global_step)
    #optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost,global_step = global_step) #不适用指数衰减
    with tf.control_dependencies([optimizer,variable_averages_op]):
        train_op= tf.no_op(name = 'train')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)        
        if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess,ckpt.model_checkpoint_path)
        for i in range(num_epochs):
            x,y = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op,feed_dict={X:x,Y:y})
            if i%500 == 0:
                cost_v = sess.run(cost,feed_dict={X:x,Y:y})
                costs.append(cost_v)
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step = global_step)
                print(i,cost_v)
       # Calculate the correct accuracy
        correct_prediction = tf.equal(tf.argmax(Y_avg,1), tf.argmax(Y,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        print ("Train Accuracy:", accuracy.eval({X:mnist.train.images, Y: mnist.train.labels}))  #
        print ("Test Accuracy:", accuracy.eval({X: mnist.test.images, Y: mnist.test.labels}))
    plt.plot(np.squeeze(costs))
    plt.ylabel('cost')
    plt.xlabel('iterations (per tens)')
    plt.title("Learning rate =" + str(learning_rate))
    plt.show()

(7)模型评估

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        X = tf.placeholder(tf.float32, shape=(None,n_x), name="X")  #(55000,784)
        Y = tf.placeholder(tf.float32, shape=(None,n_y), name="Y")  #(55000,10)
        test_feed = {X: mnist.test.images, Y: mnist.test.labels}
        prameters = init_para()
        Y_ = forward(X, prameters, None, None)
        correct_prediction = tf.equal(tf.argmax(Y_,1), tf.argmax(Y,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECCAY)
        variable_averages_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variable_averages_restore)
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess,ckpt.model_checkpoint_path)
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                accuracy_feed = sess.run(accuracy, feed_dict = test_feed)
                print("After %s training steps, valadation accuracy = %g" %(global_step,accuracy_feed))
            else:
                print("No checkpoint file found")


(8)主程序

if __name__ =='__main__':
    ops.reset_default_graph()   
    train()
    evaluate(mnist)


相关文章
|
17天前
|
容灾 网络协议 数据库
云卓越架构:云上网络稳定性建设和应用稳定性治理最佳实践
本文介绍了云上网络稳定性体系建设的关键内容,包括面向失败的架构设计、可观测性与应急恢复、客户案例及阿里巴巴的核心电商架构演进。首先强调了网络稳定性的挑战及其应对策略,如责任共担模型和冗余设计。接着详细探讨了多可用区部署、弹性架构规划及跨地域容灾设计的最佳实践,特别是阿里云的产品和技术如何助力实现高可用性和快速故障恢复。最后通过具体案例展示了秒级故障转移的效果,以及同城多活架构下的实际应用。这些措施共同确保了业务在面对网络故障时的持续稳定运行。
|
6月前
|
缓存 数据安全/隐私保护 Kotlin
Kotlin 中的网络请求代理设置最佳实践
Kotlin 中的网络请求代理设置最佳实践
|
4月前
|
数据采集 存储 监控
网络爬虫的最佳实践:结合 set_time_limit() 与 setTrafficLimit() 抓取云盘数据
本文探讨了如何利用 PHP 的 `set_time_limit()` 与爬虫工具的 `setTrafficLimit()` 方法,结合多线程和代理 IP 技术,高效稳定地抓取百度云盘的公开资源。通过设置脚本执行时间和流量限制,使用多线程提高抓取效率,并通过代理 IP 防止 IP 封禁,确保长时间稳定运行。文章还提供了示例代码,展示了如何具体实现这一过程,并加入了数据分类统计功能以监控抓取效果。
81 16
网络爬虫的最佳实践:结合 set_time_limit() 与 setTrafficLimit() 抓取云盘数据
|
3月前
|
安全 物联网 物联网安全
探索未来网络:物联网安全的最佳实践
随着物联网设备的普及,我们的世界变得越来越互联。然而,这也带来了新的安全挑战。本文将探讨在设计、实施和维护物联网系统时,如何遵循一些最佳实践来确保其安全性。通过深入分析各种案例和策略,我们将揭示如何保护物联网设备免受潜在威胁,同时保持其高效运行。
92 5
|
4月前
|
机器学习/深度学习 安全 物联网安全
探索未来网络:物联网安全的最佳实践与创新策略
本文旨在深入探讨物联网(IoT)的安全性问题,分析其面临的主要威胁与挑战,并提出一系列创新性的解决策略。通过技术解析、案例研究与前瞻展望,本文不仅揭示了物联网安全的复杂性,还展示了如何通过综合手段提升设备、数据及网络的安全性。我们强调了跨学科合作的重要性,以及在快速发展的技术环境中保持敏捷与适应性的必要性,为业界和研究者提供了宝贵的参考与启示。
|
5月前
|
SQL 安全 API
数字堡垒之下:网络安全漏洞、加密技术与安全意识的博弈探索RESTful API设计的最佳实践
【8月更文挑战第27天】在数字化浪潮中,网络安全成为守护个人隐私与企业资产的关键防线。本文深入探讨了网络漏洞的成因与影响,分析了加密技术如何为数据保驾护航,并强调了提升公众的安全意识对于构建坚固的信息防御系统的重要性。文章旨在为读者提供一场思维的盛宴,启发更多关于如何在日益复杂的网络世界中保护自己的思考。
|
5月前
|
监控 安全 网络安全
保护网络免受 DDoS 攻击的最佳实践
【8月更文挑战第24天】
110 1
|
4月前
|
存储 安全 物联网
探索未来网络:物联网安全的最佳实践与挑战
在数字化浪潮中,物联网作为连接万物的关键技术,已深刻改变我们的工作与生活方式。然而,随着其应用的广泛化,安全问题日益凸显,成为制约物联网发展的重要瓶颈。本文旨在深入探讨物联网的安全架构、风险点及应对策略,通过分析当前技术趋势和实际案例,提出一套切实可行的安全防护方案,以促进物联网技术的健康发展。
|
5月前
|
机器学习/深度学习 API 算法框架/工具
【Tensorflow+keras】Keras API三种搭建神经网络的方式及以mnist举例实现
使用Keras API构建神经网络的三种方法:使用Sequential模型、使用函数式API以及通过继承Model类来自定义模型,并提供了基于MNIST数据集的示例代码。
71 12
|
5月前
|
开发者 图形学 API
从零起步,深度揭秘:运用Unity引擎及网络编程技术,一步步搭建属于你的实时多人在线对战游戏平台——详尽指南与实战代码解析,带你轻松掌握网络化游戏开发的核心要领与最佳实践路径
【8月更文挑战第31天】构建实时多人对战平台是技术与创意的结合。本文使用成熟的Unity游戏开发引擎,从零开始指导读者搭建简单的实时对战平台。内容涵盖网络架构设计、Unity网络API应用及客户端与服务器通信。首先,创建新项目并选择适合多人游戏的模板,使用推荐的网络传输层。接着,定义基本玩法,如2D多人射击游戏,创建角色预制件并添加Rigidbody2D组件。然后,引入网络身份组件以同步对象状态。通过示例代码展示玩家控制逻辑,包括移动和发射子弹功能。最后,设置服务器端逻辑,处理客户端连接和断开。本文帮助读者掌握构建Unity多人对战平台的核心知识,为进一步开发打下基础。
182 0