TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)-阿里云开发者社区

开发者社区> 一个处女座的程序猿> 正文

TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

简介: TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)
+关注继续查看

输出结果


第 0 accuracy 0.125

第 20 accuracy 0.6484375

第 40 accuracy 0.78125

第 60 accuracy 0.9296875

第 80 accuracy 0.8671875

第 100 accuracy 0.90625

第 120 accuracy 0.8671875

第 140 accuracy 0.8671875

第 160 accuracy 0.8671875

第 180 accuracy 0.921875

第 200 accuracy 0.890625

第 220 accuracy 0.953125

第 240 accuracy 0.921875

第 260 accuracy 0.9296875

第 280 accuracy 0.9140625

第 300 accuracy 0.921875

第 320 accuracy 0.9609375

第 340 accuracy 0.953125

第 360 accuracy 0.984375

第 380 accuracy 0.921875

第 400 accuracy 0.9453125

第 420 accuracy 0.921875

第 440 accuracy 0.9296875

第 460 accuracy 0.96875

第 480 accuracy 0.984375

第 500 accuracy 0.96875

第 520 accuracy 0.953125

第 540 accuracy 0.96875

第 560 accuracy 0.953125

第 580 accuracy 0.9921875

第 600 accuracy 0.984375

第 620 accuracy 0.953125

第 640 accuracy 0.953125

第 660 accuracy 0.9921875

第 680 accuracy 0.96875

第 700 accuracy 0.9765625

第 720 accuracy 0.96875

第 740 accuracy 0.9921875

第 760 accuracy 0.984375

第 780 accuracy 0.953125

image.png



设计思路

image.png


代码设计

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

lr=0.001                  

training_iters=100000      

batch_size=128              

n_inputs=28    

n_steps=28      

n_hidden_units=128

n_classes=10        

x=tf.placeholder(tf.float32, [None,n_steps,n_inputs])

y=tf.placeholder(tf.float32, [None,n_classes])

weights ={

   'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),

   'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),

   }

biases ={

   'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),

   'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),

   }

def RNN(X,weights,biases):

   X=tf.reshape(X,[-1,n_inputs])

   X_in=tf.matmul(X,weights['in'])+biases['in']  

   X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])

   lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)

   __init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)

   outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False)

         

   outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))

   results=tf.matmul(outputs[-1],weights['out'])+biases['out']

   return results

pred =RNN(x,weights,biases)

cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

train_op=tf.train.AdamOptimizer(lr).minimize(cost)                

correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))              

accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))            

<br>

with tf.Session() as sess:

   sess.run(init)

   step=0

   while step*batch_size < training_iters:                

       batch_xs,batch_ys=mnist.train.next_batch(batch_size)

       batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])

       sess.run([train_op],feed_dict={

           x:batch_xs,

           y:batch_ys,})

       if step%20==0:                                        

           print(sess.run(accuracy,feed_dict={

               x:batch_xs,

               y:batch_ys,}))

       step+=1


image.png

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

相关文章
PostgreSQL不同模式(SCHEMA)之间迁移数据
PostgreSQL不同模式(SCHEMA)之间迁移数据。
8043 0
Python中关于类和函数的初体验之"__init__"和"__str__"不是"_init_"和"_str_"
刚刚接触Python,今天就是怎么也调试不过去了,上网上查直到晚上才查到一个有效信息,真是坑啊!原来Python中的这些“魔法”方法的命名里就有陷阱…… 上图中的那两个红圈圈,一定要记住哦,这些Python自带的方法,比如str和init前后都是两个"_",写一个"_"按F5运行肯定有问题! ...
747 0
DL之Attention:基于ClutteredMNIST手写数字图片数据集分别利用CNN_Init、ST_CNN算法(CNN+SpatialTransformer)实现多分类预测(二)
DL之Attention:基于ClutteredMNIST手写数字图片数据集分别利用CNN_Init、ST_CNN算法(CNN+SpatialTransformer)实现多分类预测
21 0
DL之Attention:基于ClutteredMNIST手写数字图片数据集分别利用CNN_Init、ST_CNN算法(CNN+SpatialTransformer)实现多分类预测(一)
DL之Attention:基于ClutteredMNIST手写数字图片数据集分别利用CNN_Init、ST_CNN算法(CNN+SpatialTransformer)实现多分类预测
31 0
[Unity3d]Unity系统自带函数生命周期以及相互关系
Unity脚本从唤醒到销毁都有着一套比较完善的生命周期,添加任何脚本都要遵守生命周期法则! 接下来介绍几种系统自调用的重要方法。首先要我们先来说明一下它们的执行顺序: Awake --> Start --> Update --> FixedUpdate --> LateUpdate -->OnGUI -->Reset --> OnDisable -->OnDestroy 下面我们针对每一个方法进行详细的说明: 1.Awake:用于在游戏开始之前初始化变量或游戏状态。
1019 0
写一个函数对字符串数组排序,使所有变位词都相邻
题目 写一个函数对字符串数组排序,使得所有的变位词都相邻。 解答 首先,要弄清楚什么是变位词。变位词就是组成的字母相同,但顺序不一样的单词。 比如说:live和evil就是一对变位词。OK,那么这道题目的意思就很清楚了, 它并不要求我们将字符串数组中的字符串按字典序排序,否则我们直接调用STL中的sort 函数就可以了。
729 0
一站式数据采集存储的利器:阿里云InfluxDB®️数据采集服务
随着时序数据的飞速增长,时序数据库不仅需要解决系统的稳定性和性能问题,还需实现数据从采集到分析的链路打通,才能让时序数据真正产生价值。
1641 0
+关注
一个处女座的程序猿
国内互联网圈知名博主、人工智能领域优秀创作者,全球最大中文IT社区博客专家、CSDN开发者联盟生态成员、中国开源社区专家、华为云社区专家、51CTO社区专家、Python社区专家等,曾受邀采访和评审十多次。仅在国内的CSDN平台,博客文章浏览量超过2500万,拥有超过57万的粉丝。
1701
文章
0
问答
文章排行榜
最热
最新
相关电子书
更多
文娱运维技术
立即下载
《SaaS模式云原生数据仓库应用场景实践》
立即下载
《看见新力量:二》电子书
立即下载