TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

简介: TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

输出结果


第 0 accuracy 0.125

第 20 accuracy 0.6484375

第 40 accuracy 0.78125

第 60 accuracy 0.9296875

第 80 accuracy 0.8671875

第 100 accuracy 0.90625

第 120 accuracy 0.8671875

第 140 accuracy 0.8671875

第 160 accuracy 0.8671875

第 180 accuracy 0.921875

第 200 accuracy 0.890625

第 220 accuracy 0.953125

第 240 accuracy 0.921875

第 260 accuracy 0.9296875

第 280 accuracy 0.9140625

第 300 accuracy 0.921875

第 320 accuracy 0.9609375

第 340 accuracy 0.953125

第 360 accuracy 0.984375

第 380 accuracy 0.921875

第 400 accuracy 0.9453125

第 420 accuracy 0.921875

第 440 accuracy 0.9296875

第 460 accuracy 0.96875

第 480 accuracy 0.984375

第 500 accuracy 0.96875

第 520 accuracy 0.953125

第 540 accuracy 0.96875

第 560 accuracy 0.953125

第 580 accuracy 0.9921875

第 600 accuracy 0.984375

第 620 accuracy 0.953125

第 640 accuracy 0.953125

第 660 accuracy 0.9921875

第 680 accuracy 0.96875

第 700 accuracy 0.9765625

第 720 accuracy 0.96875

第 740 accuracy 0.9921875

第 760 accuracy 0.984375

第 780 accuracy 0.953125

image.png



设计思路

image.png


代码设计

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

lr=0.001                  

training_iters=100000      

batch_size=128              

n_inputs=28    

n_steps=28      

n_hidden_units=128

n_classes=10        

x=tf.placeholder(tf.float32, [None,n_steps,n_inputs])

y=tf.placeholder(tf.float32, [None,n_classes])

weights ={

   'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),

   'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),

   }

biases ={

   'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),

   'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),

   }

def RNN(X,weights,biases):

   X=tf.reshape(X,[-1,n_inputs])

   X_in=tf.matmul(X,weights['in'])+biases['in']  

   X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])

   lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)

   __init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)

   outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False)

       

   outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))

   results=tf.matmul(outputs[-1],weights['out'])+biases['out']

   return results

pred =RNN(x,weights,biases)

cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

train_op=tf.train.AdamOptimizer(lr).minimize(cost)                

correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))              

accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))            

<br>

with tf.Session() as sess:

   sess.run(init)

   step=0

   while step*batch_size < training_iters:                

       batch_xs,batch_ys=mnist.train.next_batch(batch_size)

       batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])

       sess.run([train_op],feed_dict={

           x:batch_xs,

           y:batch_ys,})

       if step%20==0:                                        

           print(sess.run(accuracy,feed_dict={

               x:batch_xs,

               y:batch_ys,}))

       step+=1


image.png

相关文章
|
2月前
|
XML JavaScript 前端开发
学习react基础(1)_虚拟dom、diff算法、函数和class创建组件
本文介绍了React的核心概念,包括虚拟DOM、Diff算法以及如何通过函数和类创建React组件。
25 2
|
3月前
|
数据采集 机器学习/深度学习 算法
【python】python客户信息审计风险决策树算法分类预测(源码+数据集+论文)【独一无二】
【python】python客户信息审计风险决策树算法分类预测(源码+数据集+论文)【独一无二】
|
3月前
|
算法
【Azure Developer】完成算法第4版书中,第一节基础编码中的数组函数 histogrm()
【Azure Developer】完成算法第4版书中,第一节基础编码中的数组函数 histogrm()
|
4月前
|
算法 Python
`scipy.optimize`模块提供了许多用于优化问题的函数和算法。这些算法可以用于找到函数的最小值、最大值、零点等。
`scipy.optimize`模块提供了许多用于优化问题的函数和算法。这些算法可以用于找到函数的最小值、最大值、零点等。
|
4月前
|
算法 安全 数据安全/隐私保护
支付系统---微信支付09------数字签名,现在Bob想要给Pink写一封信,信件的内容不需要加密,怎样能够保证信息的完整性,使用信息完整性的主要手段是摘要算法,散列函数,哈希函数,H称为数据指纹
支付系统---微信支付09------数字签名,现在Bob想要给Pink写一封信,信件的内容不需要加密,怎样能够保证信息的完整性,使用信息完整性的主要手段是摘要算法,散列函数,哈希函数,H称为数据指纹
|
5月前
|
算法 vr&ar
技术好文共享:遗传算法解决函数优化
技术好文共享:遗传算法解决函数优化
|
3月前
|
机器学习/深度学习 API 异构计算
7.1.3.2、使用飞桨实现基于LSTM的情感分析模型的网络定义
该文章详细介绍了如何使用飞桨框架实现基于LSTM的情感分析模型,包括网络定义、模型训练、评估和预测的完整流程,并提供了相应的代码实现。
|
14天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。
|
3月前
|
机器学习/深度学习
【机器学习】面试题:LSTM长短期记忆网络的理解?LSTM是怎么解决梯度消失的问题的?还有哪些其它的解决梯度消失或梯度爆炸的方法?
长短时记忆网络(LSTM)的基本概念、解决梯度消失问题的机制,以及介绍了包括梯度裁剪、改变激活函数、残差结构和Batch Normalization在内的其他方法来解决梯度消失或梯度爆炸问题。
100 2
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
RNN、LSTM、GRU神经网络构建人名分类器(三)
这个文本描述了一个使用RNN(循环神经网络)、LSTM(长短期记忆网络)和GRU(门控循环单元)构建的人名分类器的案例。案例的主要目的是通过输入一个人名来预测它最可能属于哪个国家。这个任务在国际化的公司中很重要,因为可以自动为用户注册时提供相应的国家或地区选项。