TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

简介: TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

输出结果


第 0 accuracy 0.125

第 20 accuracy 0.6484375

第 40 accuracy 0.78125

第 60 accuracy 0.9296875

第 80 accuracy 0.8671875

第 100 accuracy 0.90625

第 120 accuracy 0.8671875

第 140 accuracy 0.8671875

第 160 accuracy 0.8671875

第 180 accuracy 0.921875

第 200 accuracy 0.890625

第 220 accuracy 0.953125

第 240 accuracy 0.921875

第 260 accuracy 0.9296875

第 280 accuracy 0.9140625

第 300 accuracy 0.921875

第 320 accuracy 0.9609375

第 340 accuracy 0.953125

第 360 accuracy 0.984375

第 380 accuracy 0.921875

第 400 accuracy 0.9453125

第 420 accuracy 0.921875

第 440 accuracy 0.9296875

第 460 accuracy 0.96875

第 480 accuracy 0.984375

第 500 accuracy 0.96875

第 520 accuracy 0.953125

第 540 accuracy 0.96875

第 560 accuracy 0.953125

第 580 accuracy 0.9921875

第 600 accuracy 0.984375

第 620 accuracy 0.953125

第 640 accuracy 0.953125

第 660 accuracy 0.9921875

第 680 accuracy 0.96875

第 700 accuracy 0.9765625

第 720 accuracy 0.96875

第 740 accuracy 0.9921875

第 760 accuracy 0.984375

第 780 accuracy 0.953125

image.png



设计思路

image.png


代码设计

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

lr=0.001                  

training_iters=100000      

batch_size=128              

n_inputs=28    

n_steps=28      

n_hidden_units=128

n_classes=10        

x=tf.placeholder(tf.float32, [None,n_steps,n_inputs])

y=tf.placeholder(tf.float32, [None,n_classes])

weights ={

   'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),

   'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),

   }

biases ={

   'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),

   'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),

   }

def RNN(X,weights,biases):

   X=tf.reshape(X,[-1,n_inputs])

   X_in=tf.matmul(X,weights['in'])+biases['in']  

   X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])

   lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)

   __init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)

   outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False)

       

   outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))

   results=tf.matmul(outputs[-1],weights['out'])+biases['out']

   return results

pred =RNN(x,weights,biases)

cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

train_op=tf.train.AdamOptimizer(lr).minimize(cost)                

correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))              

accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))            

<br>

with tf.Session() as sess:

   sess.run(init)

   step=0

   while step*batch_size < training_iters:                

       batch_xs,batch_ys=mnist.train.next_batch(batch_size)

       batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])

       sess.run([train_op],feed_dict={

           x:batch_xs,

           y:batch_ys,})

       if step%20==0:                                        

           print(sess.run(accuracy,feed_dict={

               x:batch_xs,

               y:batch_ys,}))

       step+=1


image.png

相关文章
|
3月前
|
算法 机器人 定位技术
【VRPTW】基于matlab秃鹰算法BES求解带时间窗的骑手外卖配送路径规划问题(目标函数:最优路径成本 含服务客户数量 服务时间 载量 路径长度)(Matlab代码实现)
【VRPTW】基于matlab秃鹰算法BES求解带时间窗的骑手外卖配送路径规划问题(目标函数:最优路径成本 含服务客户数量 服务时间 载量 路径长度)(Matlab代码实现)
|
2月前
|
传感器 算法 数据挖掘
基于协方差交叉(CI)的多传感器融合算法matlab仿真,对比单传感器和SCC融合
基于协方差交叉(CI)的多传感器融合算法,通过MATLAB仿真对比单传感器、SCC与CI融合在位置/速度估计误差(RMSE)及等概率椭圆上的性能。采用MATLAB2022A实现,结果表明CI融合在未知相关性下仍具鲁棒性,有效降低估计误差。
164 15
|
2月前
|
存储 并行计算 算法
【动态多目标优化算法】基于自适应启动策略的混合交叉动态约束多目标优化算法(MC-DCMOEA)求解CEC2023研究(Matlab代码实现)
【动态多目标优化算法】基于自适应启动策略的混合交叉动态约束多目标优化算法(MC-DCMOEA)求解CEC2023研究(Matlab代码实现)
128 4
|
2月前
|
机器学习/深度学习 传感器 算法
基于matlab瞬态三角哈里斯鹰算法TTHHO多无人机协同集群避障路径规划(目标函数:最低成本:路径、高度、威胁、转角)(Matlab代码实现)
基于matlab瞬态三角哈里斯鹰算法TTHHO多无人机协同集群避障路径规划(目标函数:最低成本:路径、高度、威胁、转角)(Matlab代码实现)
104 1
|
2月前
|
存储 算法 生物认证
基于Zhang-Suen算法的图像细化处理FPGA实现,包含testbench和matlab验证程序
本项目基于Zhang-Suen算法实现图像细化处理,支持FPGA与MATLAB双平台验证。通过对比,FPGA细化效果与MATLAB一致,可有效减少图像数据量,便于后续识别与矢量化处理。算法适用于字符识别、指纹识别等领域,配套完整仿真代码及操作说明。
|
3月前
|
机器学习/深度学习 算法 数据挖掘
【配送路径规划】基于螳螂虾算法MShOA求解带时间窗的骑手外卖配送路径规划问题(目标函数:最优路径成本 含服务客户数量 服务时间 载量 路径长度)研究(Matlab代码实现)
【配送路径规划】基于螳螂虾算法MShOA求解带时间窗的骑手外卖配送路径规划问题(目标函数:最优路径成本 含服务客户数量 服务时间 载量 路径长度)研究(Matlab代码实现)
117 0
|
3月前
|
算法 Python
【配送路径规划】基于遗传算法求解带时间窗的电动汽车配送路径规划(目标函数:最小成本;约束条件:续驶里程、额定载重量、数量、起始点)研究(Matlab代码实现)
【配送路径规划】基于遗传算法求解带时间窗的电动汽车配送路径规划(目标函数:最小成本;约束条件:续驶里程、额定载重量、数量、起始点)研究(Matlab代码实现)
100 0
|
4月前
|
存储 算法 数据安全/隐私保护
基于FPGA的图像退化算法verilog实现,分别实现横向和纵向运动模糊,包括tb和MATLAB辅助验证
本项目基于FPGA实现图像运动模糊算法,包含横向与纵向模糊处理流程。使用Vivado 2019.2与MATLAB 2022A,通过一维卷积模拟点扩散函数,完成图像退化处理,并可在MATLAB中预览效果。
|
30天前
|
数据采集 分布式计算 并行计算
mRMR算法实现特征选择-MATLAB
mRMR算法实现特征选择-MATLAB
102 2
|
2月前
|
传感器 机器学习/深度学习 编解码
MATLAB|主动噪声和振动控制算法——对较大的次级路径变化具有鲁棒性
MATLAB|主动噪声和振动控制算法——对较大的次级路径变化具有鲁棒性
171 3

热门文章

最新文章