TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

简介: TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

输出结果


第 0 accuracy 0.125

第 20 accuracy 0.6484375

第 40 accuracy 0.78125

第 60 accuracy 0.9296875

第 80 accuracy 0.8671875

第 100 accuracy 0.90625

第 120 accuracy 0.8671875

第 140 accuracy 0.8671875

第 160 accuracy 0.8671875

第 180 accuracy 0.921875

第 200 accuracy 0.890625

第 220 accuracy 0.953125

第 240 accuracy 0.921875

第 260 accuracy 0.9296875

第 280 accuracy 0.9140625

第 300 accuracy 0.921875

第 320 accuracy 0.9609375

第 340 accuracy 0.953125

第 360 accuracy 0.984375

第 380 accuracy 0.921875

第 400 accuracy 0.9453125

第 420 accuracy 0.921875

第 440 accuracy 0.9296875

第 460 accuracy 0.96875

第 480 accuracy 0.984375

第 500 accuracy 0.96875

第 520 accuracy 0.953125

第 540 accuracy 0.96875

第 560 accuracy 0.953125

第 580 accuracy 0.9921875

第 600 accuracy 0.984375

第 620 accuracy 0.953125

第 640 accuracy 0.953125

第 660 accuracy 0.9921875

第 680 accuracy 0.96875

第 700 accuracy 0.9765625

第 720 accuracy 0.96875

第 740 accuracy 0.9921875

第 760 accuracy 0.984375

第 780 accuracy 0.953125

image.png



设计思路

image.png


代码设计

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

lr=0.001                  

training_iters=100000      

batch_size=128              

n_inputs=28    

n_steps=28      

n_hidden_units=128

n_classes=10        

x=tf.placeholder(tf.float32, [None,n_steps,n_inputs])

y=tf.placeholder(tf.float32, [None,n_classes])

weights ={

   'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),

   'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),

   }

biases ={

   'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),

   'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),

   }

def RNN(X,weights,biases):

   X=tf.reshape(X,[-1,n_inputs])

   X_in=tf.matmul(X,weights['in'])+biases['in']  

   X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])

   lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)

   __init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)

   outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False)

       

   outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))

   results=tf.matmul(outputs[-1],weights['out'])+biases['out']

   return results

pred =RNN(x,weights,biases)

cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

train_op=tf.train.AdamOptimizer(lr).minimize(cost)                

correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))              

accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))            

<br>

with tf.Session() as sess:

   sess.run(init)

   step=0

   while step*batch_size < training_iters:                

       batch_xs,batch_ys=mnist.train.next_batch(batch_size)

       batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])

       sess.run([train_op],feed_dict={

           x:batch_xs,

           y:batch_ys,})

       if step%20==0:                                        

           print(sess.run(accuracy,feed_dict={

               x:batch_xs,

               y:batch_ys,}))

       step+=1


image.png

相关文章
|
2月前
|
算法 安全 Go
Go 语言中实现 RSA 加解密、签名验证算法
随着互联网的发展,安全需求日益增长。非对称加密算法RSA成为密码学中的重要代表。本文介绍如何使用Go语言和[forgoer/openssl](https://github.com/forgoer/openssl)库简化RSA加解密操作,包括秘钥生成、加解密及签名验证。该库还支持AES、DES等常用算法,安装简便,代码示例清晰易懂。
61 12
|
4月前
|
监控 算法 数据安全/隐私保护
基于三帧差算法的运动目标检测系统FPGA实现,包含testbench和MATLAB辅助验证程序
本项目展示了基于FPGA与MATLAB实现的三帧差算法运动目标检测。使用Vivado 2019.2和MATLAB 2022a开发环境,通过对比连续三帧图像的像素值变化,有效识别运动区域。项目包括完整无水印的运行效果预览、详细中文注释的代码及操作步骤视频,适合学习和研究。
|
4月前
|
算法 搜索推荐 Java
java 后端 使用 Graphics2D 制作海报,画echarts图,带工具类,各种细节:如头像切割成圆形,文字换行算法(完美实验success),解决画上文字、图片后不清晰问题
这篇文章介绍了如何使用Java后端技术,结合Graphics2D和Echarts等工具,生成包含个性化信息和图表的海报,并提供了详细的代码实现和GitHub项目链接。
203 0
java 后端 使用 Graphics2D 制作海报,画echarts图,带工具类,各种细节:如头像切割成圆形,文字换行算法(完美实验success),解决画上文字、图片后不清晰问题
|
5月前
|
算法 搜索推荐 开发者
别再让复杂度拖你后腿!Python 算法设计与分析实战,教你如何精准评估与优化!
在 Python 编程中,算法的性能至关重要。本文将带您深入了解算法复杂度的概念,包括时间复杂度和空间复杂度。通过具体的例子,如冒泡排序算法 (`O(n^2)` 时间复杂度,`O(1)` 空间复杂度),我们将展示如何评估算法的性能。同时,我们还会介绍如何优化算法,例如使用 Python 的内置函数 `max` 来提高查找最大值的效率,或利用哈希表将查找时间从 `O(n)` 降至 `O(1)`。此外,还将介绍使用 `timeit` 模块等工具来评估算法性能的方法。通过不断实践,您将能更高效地优化 Python 程序。
88 4
|
4月前
|
算法 Java Linux
java制作海报一:java使用Graphics2D 在图片上写字,文字换行算法详解
这篇文章介绍了如何在Java中使用Graphics2D在图片上绘制文字,并实现自动换行的功能。
237 0
|
7月前
|
算法 搜索推荐 开发者
别再让复杂度拖你后腿!Python 算法设计与分析实战,教你如何精准评估与优化!
【7月更文挑战第23天】在Python编程中,掌握算法复杂度—时间与空间消耗,是提升程序效能的关键。算法如冒泡排序($O(n^2)$时间/$O(1)$空间),或使用Python内置函数找最大值($O(n)$时间),需精确诊断与优化。数据结构如哈希表可将查找从$O(n)$降至$O(1)$。运用`timeit`模块评估性能,深入理解数据结构和算法,使Python代码更高效。持续实践与学习,精通复杂度管理。
71 9
|
6月前
|
机器学习/深度学习 算法 搜索推荐
支付宝商业化广告算法问题之在DNN模型中,特征的重要性如何评估
支付宝商业化广告算法问题之在DNN模型中,特征的重要性如何评估
|
7月前
|
机器学习/深度学习 数据采集 算法
Python实现贝叶斯岭回归模型(BayesianRidge算法)并使用K折交叉验证进行模型评估项目实战
Python实现贝叶斯岭回归模型(BayesianRidge算法)并使用K折交叉验证进行模型评估项目实战
|
7月前
|
文字识别 算法 Java
文本,保存图片09,一个可以用id作为图片名字的pom插件,利用雪花算法生成唯一的id
文本,保存图片09,一个可以用id作为图片名字的pom插件,利用雪花算法生成唯一的id
|
8月前
|
机器学习/深度学习 算法
GBDT算法超参数评估(二)
GBDT算法超参数评估关注决策树的不纯度指标,如基尼系数和信息熵,两者衡量数据纯度,影响树的生长。默认使用基尼系数,计算快速,而信息熵更敏感但计算慢。GBDT的弱评估器默认最大深度为3,限制了过拟合,不同于随机森林。由于Boosting的内在机制,过拟合控制更多依赖数据和参数如`max_features`。相比Bagging,Boosting通常不易过拟合。评估模型常用`cross_validate`和`KFold`交叉验证。