TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类

简介: TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类

设计思路

网络异常,图片无法展示
|

实现代码

 

# -*- coding:utf-8 -*-

import tensorflow as tf

import numpy as np

from tensorflow.contrib import rnn

from tensorflow.examples.tutorials.mnist import input_data

#根据电脑情况设置 GPU

config = tf.ConfigProto()

config.gpu_options.allow_growth = True

sess = tf.Session(config=config)

# 1、定义数据集

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

print(mnist.train.images.shape)

#2、定义模型超参数

lr = 1e-3

# batch_size = 128

batch_size = tf.placeholder(tf.int32)  #采用占位符的方式,因为在训练和测试的时候要用不同的batch_size。注意类型必须为 tf.int32

input_size = 28      # 每个时刻的输入特征是28维的,就是每个时刻输入一行,一行有 28 个像素

timestep_size = 28   # 时序持续长度为28,即每做一次预测,需要先输入28行

hidden_size = 256    # 每个隐含层的节点数

layer_num = 2        # LSTM layer 的层数

class_num = 10       # 最后输出分类类别数量,如果是回归预测的话应该是 1

_X = tf.placeholder(tf.float32, [None, 784])

y = tf.placeholder(tf.float32, [None, class_num])

keep_prob = tf.placeholder(tf.float32)

#3、LSTM模型的搭建、训练、测试

#3.1、LSTM模型的搭建

X = tf.reshape(_X, [-1, 28, 28])  #RNN 的输入shape = (batch_size, timestep_size, input_size),把784个点的字符信息还原成 28 * 28 的图片

lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)      #定义一层 LSTM_cell,只需要说明 hidden_size, 它会自动匹配输入的 X 的维度

lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob) #添加 dropout layer, 一般只设置 output_keep_prob

mlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True)  #调用 MultiRNNCell来实现多层 LSTM

init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)             #用全零来初始化state

#3.2、LSTM模型的运行:构建好的网络运行起来

#T1、调用 dynamic_rnn()法

# ** 当 time_major==False 时, outputs.shape = [batch_size, timestep_size, hidden_size],所以,可以取 h_state = outputs[:, -1, :] 作为最后输出

# ** state.shape = [layer_num, 2, batch_size, hidden_size],或者,可以取 h_state = state[-1][1] 作为最后输出,最后输出维度是 [batch_size, hidden_size]

# outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)

# h_state = outputs[:, -1, :]  # 或者 h_state = state[-1][1]

#T2、自定义LSTM迭代按时间步展开计算:为了更好的理解 LSTM 工作原理把T1的函数自己来实现

#(1)、可以采用RNNCell的 __call__()函数,来实现LSTM按时间步迭代。

outputs = list()

state = init_state

with tf.variable_scope('RNN'):

   for timestep in range(timestep_size):

       if timestep > 0:

           tf.get_variable_scope().reuse_variables()

       (cell_output, state) = mlstm_cell(X[:, timestep, :], state)   # 这里的state保存了每一层 LSTM 的状态

       outputs.append(cell_output)

h_state = outputs[-1]

#3.3、LSTM模型的训练

# 定义 softmax 的连接权重矩阵和偏置:上面 LSTM 部分的输出会是一个 [hidden_size] 的tensor,我们要分类的话,还需要接一个 softmax 层

# out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights')

# out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias')

W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)

bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)

y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)

#定义损失和评估函数

cross_entropy = -tf.reduce_mean(y * tf.log(y_pre))

train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

sess.run(tf.global_variables_initializer())

for i in range(2000):

   _batch_size = 128

   batch = mnist.train.next_batch(_batch_size)

   if (i+1)%200 == 0:

       train_accuracy = sess.run(accuracy, feed_dict={

           _X:batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size})

       # 已经迭代完成的 epoch 数: mnist.train.epochs_completed

       print("Iter%d, step %d, training accuracy %g" % ( mnist.train.epochs_completed, (i+1), train_accuracy))

   sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})

# 计算测试数据的准确率

print("test accuracy %g"% sess.run(accuracy, feed_dict={_X: mnist.test.images, y: mnist.test.labels,

                                                       keep_prob: 1.0, batch_size:mnist.test.images.shape[0]}))




相关文章
|
2月前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
116 4
|
4月前
|
机器学习/深度学习 人工智能 算法
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
文本分类识别系统。本系统使用Python作为主要开发语言,首先收集了10种中文文本数据集("体育类", "财经类", "房产类", "家居类", "教育类", "科技类", "时尚类", "时政类", "游戏类", "娱乐类"),然后基于TensorFlow搭建CNN卷积神经网络算法模型。通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型,并保存为本地的h5格式。然后使用Django开发Web网页端操作界面,实现用户上传一段文本识别其所属的类别。
123 1
【新闻文本分类识别系统】Python+卷积神经网络算法+人工智能+深度学习+计算机毕设项目+Django网页界面平台
|
3月前
|
存储 缓存 分布式计算
数据结构与算法学习一:学习前的准备,数据结构的分类,数据结构与算法的关系,实际编程中遇到的问题,几个经典算法问题
这篇文章是关于数据结构与算法的学习指南,涵盖了数据结构的分类、数据结构与算法的关系、实际编程中遇到的问题以及几个经典的算法面试题。
44 0
数据结构与算法学习一:学习前的准备,数据结构的分类,数据结构与算法的关系,实际编程中遇到的问题,几个经典算法问题
|
3月前
|
移动开发 算法 前端开发
前端常用算法全解:特征梳理、复杂度比较、分类解读与示例展示
前端常用算法全解:特征梳理、复杂度比较、分类解读与示例展示
35 0
|
4月前
|
机器学习/深度学习 算法 数据挖掘
决策树算法大揭秘:Python让你秒懂分支逻辑,精准分类不再难
【9月更文挑战第12天】决策树算法作为机器学习领域的一颗明珠,凭借其直观易懂和强大的解释能力,在分类与回归任务中表现出色。相比传统统计方法,决策树通过简单的分支逻辑实现了数据的精准分类。本文将借助Python和scikit-learn库,以鸢尾花数据集为例,展示如何使用决策树进行分类,并探讨其优势与局限。通过构建一系列条件判断,决策树不仅模拟了人类决策过程,还确保了结果的可追溯性和可解释性。无论您是新手还是专家,都能轻松上手,享受机器学习的乐趣。
57 9
|
5月前
|
机器学习/深度学习 API 异构计算
7.1.3.2、使用飞桨实现基于LSTM的情感分析模型的网络定义
该文章详细介绍了如何使用飞桨框架实现基于LSTM的情感分析模型,包括网络定义、模型训练、评估和预测的完整流程,并提供了相应的代码实现。
|
3月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。
|
5月前
|
机器学习/深度学习
【机器学习】面试题:LSTM长短期记忆网络的理解?LSTM是怎么解决梯度消失的问题的?还有哪些其它的解决梯度消失或梯度爆炸的方法?
长短时记忆网络(LSTM)的基本概念、解决梯度消失问题的机制,以及介绍了包括梯度裁剪、改变激活函数、残差结构和Batch Normalization在内的其他方法来解决梯度消失或梯度爆炸问题。
203 2
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
RNN、LSTM、GRU神经网络构建人名分类器(三)
这个文本描述了一个使用RNN(循环神经网络)、LSTM(长短期记忆网络)和GRU(门控循环单元)构建的人名分类器的案例。案例的主要目的是通过输入一个人名来预测它最可能属于哪个国家。这个任务在国际化的公司中很重要,因为可以自动为用户注册时提供相应的国家或地区选项。
|
7月前
|
机器学习/深度学习 数据采集
RNN、LSTM、GRU神经网络构建人名分类器(一)
这个文本描述了一个使用RNN(循环神经网络)、LSTM(长短期记忆网络)和GRU(门控循环单元)构建的人名分类器的案例。案例的主要目的是通过输入一个人名来预测它最可能属于哪个国家。这个任务在国际化的公司中很重要,因为可以自动为用户注册时提供相应的国家或地区选项。

热门文章

最新文章