TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

简介: TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

输出结果


第 0 accuracy 0.125

第 20 accuracy 0.6484375

第 40 accuracy 0.78125

第 60 accuracy 0.9296875

第 80 accuracy 0.8671875

第 100 accuracy 0.90625

第 120 accuracy 0.8671875

第 140 accuracy 0.8671875

第 160 accuracy 0.8671875

第 180 accuracy 0.921875

第 200 accuracy 0.890625

第 220 accuracy 0.953125

第 240 accuracy 0.921875

第 260 accuracy 0.9296875

第 280 accuracy 0.9140625

第 300 accuracy 0.921875

第 320 accuracy 0.9609375

第 340 accuracy 0.953125

第 360 accuracy 0.984375

第 380 accuracy 0.921875

第 400 accuracy 0.9453125

第 420 accuracy 0.921875

第 440 accuracy 0.9296875

第 460 accuracy 0.96875

第 480 accuracy 0.984375

第 500 accuracy 0.96875

第 520 accuracy 0.953125

第 540 accuracy 0.96875

第 560 accuracy 0.953125

第 580 accuracy 0.9921875

第 600 accuracy 0.984375

第 620 accuracy 0.953125

第 640 accuracy 0.953125

第 660 accuracy 0.9921875

第 680 accuracy 0.96875

第 700 accuracy 0.9765625

第 720 accuracy 0.96875

第 740 accuracy 0.9921875

第 760 accuracy 0.984375

第 780 accuracy 0.953125

image.png



设计思路

image.png


代码设计

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

lr=0.001                  

training_iters=100000      

batch_size=128              

n_inputs=28    

n_steps=28      

n_hidden_units=128

n_classes=10        

x=tf.placeholder(tf.float32, [None,n_steps,n_inputs])

y=tf.placeholder(tf.float32, [None,n_classes])

weights ={

   'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),

   'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),

   }

biases ={

   'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),

   'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),

   }

def RNN(X,weights,biases):

   X=tf.reshape(X,[-1,n_inputs])

   X_in=tf.matmul(X,weights['in'])+biases['in']  

   X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])

   lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)

   __init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)

   outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False)

       

   outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))

   results=tf.matmul(outputs[-1],weights['out'])+biases['out']

   return results

pred =RNN(x,weights,biases)

cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

train_op=tf.train.AdamOptimizer(lr).minimize(cost)                

correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))              

accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))            

<br>

with tf.Session() as sess:

   sess.run(init)

   step=0

   while step*batch_size < training_iters:                

       batch_xs,batch_ys=mnist.train.next_batch(batch_size)

       batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])

       sess.run([train_op],feed_dict={

           x:batch_xs,

           y:batch_ys,})

       if step%20==0:                                        

           print(sess.run(accuracy,feed_dict={

               x:batch_xs,

               y:batch_ys,}))

       step+=1


image.png

相关文章
|
2月前
|
存储 机器学习/深度学习 算法
蓝桥杯练习题(三):Python组之算法训练提高综合五十题
蓝桥杯Python编程练习题的集合,涵盖了从基础到提高的多个算法题目及其解答。
127 3
蓝桥杯练习题(三):Python组之算法训练提高综合五十题
|
1月前
|
分布式计算 Java 开发工具
阿里云MaxCompute-XGBoost on Spark 极限梯度提升算法的分布式训练与模型持久化oss的实现与代码浅析
本文介绍了XGBoost在MaxCompute+OSS架构下模型持久化遇到的问题及其解决方案。首先简要介绍了XGBoost的特点和应用场景,随后详细描述了客户在将XGBoost on Spark任务从HDFS迁移到OSS时遇到的异常情况。通过分析异常堆栈和源代码,发现使用的`nativeBooster.saveModel`方法不支持OSS路径,而使用`write.overwrite().save`方法则能成功保存模型。最后提供了完整的Scala代码示例、Maven配置和提交命令,帮助用户顺利迁移模型存储路径。
|
2月前
|
机器学习/深度学习 算法 决策智能
【机器学习】揭秘深度学习优化算法:加速训练与提升性能
【机器学习】揭秘深度学习优化算法:加速训练与提升性能
|
2月前
|
算法 搜索推荐 Java
java 后端 使用 Graphics2D 制作海报,画echarts图,带工具类,各种细节:如头像切割成圆形,文字换行算法(完美实验success),解决画上文字、图片后不清晰问题
这篇文章介绍了如何使用Java后端技术,结合Graphics2D和Echarts等工具,生成包含个性化信息和图表的海报,并提供了详细的代码实现和GitHub项目链接。
156 0
java 后端 使用 Graphics2D 制作海报,画echarts图,带工具类,各种细节:如头像切割成圆形,文字换行算法(完美实验success),解决画上文字、图片后不清晰问题
|
2月前
|
算法 Java C++
【贪心算法】算法训练 ALGO-1003 礼物(C/C++)
【贪心算法】算法训练 ALGO-1003 礼物(C/C++)
【贪心算法】算法训练 ALGO-1003 礼物(C/C++)
|
2月前
|
算法 Java Linux
java制作海报一:java使用Graphics2D 在图片上写字,文字换行算法详解
这篇文章介绍了如何在Java中使用Graphics2D在图片上绘制文字,并实现自动换行的功能。
153 0
|
2月前
|
算法 C++
蓝桥 算法训练 共线(C++)
蓝桥 算法训练 共线(C++)
|
4月前
|
算法 搜索推荐
支付宝商业化广告算法问题之基于pretrain—>finetune范式的知识迁移中,finetune阶段全参数训练与部分参数训练的效果如何比较
支付宝商业化广告算法问题之基于pretrain—>finetune范式的知识迁移中,finetune阶段全参数训练与部分参数训练的效果如何比较
|
4月前
|
存储 算法
【C算法】编程初学者入门训练140道(1~20)
【C算法】编程初学者入门训练140道(1~20)
|
2天前
|
机器学习/深度学习 算法
基于改进遗传优化的BP神经网络金融序列预测算法matlab仿真
本项目基于改进遗传优化的BP神经网络进行金融序列预测,使用MATLAB2022A实现。通过对比BP神经网络、遗传优化BP神经网络及改进遗传优化BP神经网络,展示了三者的误差和预测曲线差异。核心程序结合遗传算法(GA)与BP神经网络,利用GA优化BP网络的初始权重和阈值,提高预测精度。GA通过选择、交叉、变异操作迭代优化,防止局部收敛,增强模型对金融市场复杂性和不确定性的适应能力。
103 80

热门文章

最新文章