TF之LSTM:利用基于顺序的LSTM回归算法对DIY数据集sin曲线(蓝虚)预测cos(红实)(matplotlib动态演示)—daiding

简介: TF之LSTM:利用基于顺序的LSTM回归算法对DIY数据集sin曲线(蓝虚)预测cos(红实)(matplotlib动态演示)—daiding


目录

输出结果

代码设计


 

 

 

 

输出结果

 

 

代码设计

1. import tensorflow as tf
2. import numpy as np
3. import matplotlib.pyplot as plt
4. 
5. BATCH_START = 0
6. TIME_STEPS = 20
7. BATCH_SIZE = 50
8. INPUT_SIZE = 1
9. OUTPUT_SIZE = 1
10. CELL_SIZE = 10
11. LR = 0.006
12. BATCH_START_TEST = 0
13. 
14. def get_batch():    
15. global BATCH_START, TIME_STEPS
16. # xs shape (50batch, 20steps)
17.     xs = np.arange(BATCH_START, BATCH_START+TIME_STEPS*BATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10*np.pi)
18.     seq = np.sin(xs)
19.     res = np.cos(xs)
20.     BATCH_START += TIME_STEPS
21. return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs]
22. 
23. 
24. class LSTMRNN(object):  
25. def __init__(self, n_steps, input_size, output_size, cell_size, batch_size):
26.         self.n_steps = n_steps
27.         self.input_size = input_size
28.         self.output_size = output_size
29.         self.cell_size = cell_size
30.         self.batch_size = batch_size
31. with tf.name_scope('inputs'):
32.             self.xs = tf.placeholder(tf.float32, [None, n_steps, input_size], name='xs')
33.             self.ys = tf.placeholder(tf.float32, [None, n_steps, output_size], name='ys')
34. with tf.variable_scope('in_hidden'):
35.             self.add_input_layer()
36. with tf.variable_scope('LSTM_cell'):
37.             self.add_cell()
38. with tf.variable_scope('out_hidden'):
39.             self.add_output_layer()
40. with tf.name_scope('cost'):
41.             self.compute_cost()       
42. with tf.name_scope('train'):
43.             self.train_op = tf.train.AdamOptimizer(LR).minimize(self.cost)
44. 
45. def add_input_layer(self,):  
46.         l_in_x = tf.reshape(self.xs, [-1, self.input_size], name='2_2D') 
47.         Ws_in = self._weight_variable([self.input_size, self.cell_size])
48.         bs_in = self._bias_variable([self.cell_size,])
49. with tf.name_scope('Wx_plus_b'):
50.             l_in_y = tf.matmul(l_in_x, Ws_in) + bs_in
51.         self.l_in_y = tf.reshape(l_in_y, [-1, self.n_steps, self.cell_size], name='2_3D')
52. 
53. def add_cell(self):       
54.         lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.cell_size, forget_bias=1.0, state_is_tuple=True)
55. with tf.name_scope('initial_state'): 
56.             self.cell_init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32) 
57.         self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn( 
58.             lstm_cell, self.l_in_y, initial_state=self.cell_init_state, time_major=False)  
59. 
60. def add_output_layer(self):  
61.         l_out_x = tf.reshape(self.cell_outputs, [-1, self.cell_size], name='2_2D')
62.         Ws_out = self._weight_variable([self.cell_size, self.output_size])
63.         bs_out = self._bias_variable([self.output_size, ])
64. with tf.name_scope('Wx_plus_b'):
65.             self.pred = tf.matmul(l_out_x, Ws_out) + bs_out
66. 
67. def compute_cost(self):
68.         losses = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
69.             [tf.reshape(self.pred, [-1], name='reshape_pred')],
70.             [tf.reshape(self.ys, [-1], name='reshape_target')],
71.             [tf.ones([self.batch_size * self.n_steps], dtype=tf.float32)],
72.             average_across_timesteps=True,
73.             softmax_loss_function=self.ms_error,
74.             name='losses'
75.         )
76. with tf.name_scope('average_cost'):
77.             self.cost = tf.div(
78.                 tf.reduce_sum(losses, name='losses_sum'),
79.                 self.batch_size,
80.                 name='average_cost')
81.             tf.summary.scalar('cost', self.cost)
82. 
83. def ms_error(self, y_target, y_pre):  
84. return tf.square(tf.sub(y_target, y_pre)) 
85. 
86. def _weight_variable(self, shape, name='weights'):
87.         initializer = tf.random_normal_initializer(mean=0., stddev=1.,)
88. return tf.get_variable(shape=shape, initializer=initializer, name=name)
89. 
90. def _bias_variable(self, shape, name='biases'):
91.         initializer = tf.constant_initializer(0.1)
92. return tf.get_variable(name=name, shape=shape, initializer=initializer)
93. 
94. if __name__ == '__main__':  
95.     model = LSTMRNN(TIME_STEPS, INPUT_SIZE, OUTPUT_SIZE, CELL_SIZE, BATCH_SIZE)
96.     sess = tf.Session()
97.     merged=tf.summary.merge_all()
98.     writer=tf.summary.FileWriter("niu0127/logs0127",sess.graph)
99.     sess.run(tf.initialize_all_variables())
100. 
101. plt.ion()  
102. plt.show() 
103. 
104. for i in range(200):
105.       seq, res, xs = get_batch() 
106. if i == 0:
107.           feed_dict = {
108.                   model.xs: seq,
109.                   model.ys: res,
110.           }
111. else:
112.           feed_dict = {
113.               model.xs: seq,
114.               model.ys: res,
115.               model.cell_init_state: state  
116.           }
117.       _, cost, state, pred = sess.run(
118.           [model.train_op, model.cost, model.cell_final_state, model.pred],
119.           feed_dict=feed_dict)
120. 
121.       plt.plot(xs[0,:],res[0].flatten(),'r',xs[0,:],pred.flatten()[:TIME_STEPS],'g--')
122.       plt.title('Matplotlib,RNN,Efficient learning,Approach,Cosx --Jason Niu')
123.       plt.ylim((-1.2,1.2))
124.       plt.draw()
125.       plt.pause(0.1)

 


目录
打赏
0
0
0
0
1043
分享
相关文章
DSA与RSA的区别、ECC(椭圆曲线数字签名算法(ECDSA))
DSA与RSA的区别、ECC(椭圆曲线数字签名算法(ECDSA))
759 0
基于NURBS曲线的数据拟合算法matlab仿真
本程序基于NURBS曲线实现数据拟合,适用于计算机图形学、CAD/CAM等领域。通过控制顶点和权重,精确表示复杂形状,特别适合真实对象建模和数据点光滑拟合。程序在MATLAB2022A上运行,展示了T1至T7的测试结果,无水印输出。核心算法采用梯度下降等优化技术调整参数,最小化误差函数E,确保迭代收敛,提供高质量的拟合效果。
Python通过matplotlib动态绘图实现中美GDP历年对比趋势动图
随着中国的各种实力的提高,经常在各种媒体上看到中国与各个国家历年的各种指标数据的对比,为了更清楚的展示历年的发展趋势,有的还做成了动图,看到中国各种指标数据的近年的不断逆袭,心中的自豪感油然而生。今天通过Python来实现matplotlib的动态绘图,将中美两国近年的GDP做个对比,展示中国GPD对美国的追赶态势,相信不久的将来中国的GDP数据将稳超美国。
226 2
【计算机网络】—— IP协议及动态路由算法(下)
【计算机网络】—— IP协议及动态路由算法(下)
123 0
【计算机网络】—— IP协议及动态路由算法(上)
【计算机网络】—— IP协议及动态路由算法(上)
418 0
【视频】时间序列分类方法:动态时间规整算法DTW和R语言实现
【视频】时间序列分类方法:动态时间规整算法DTW和R语言实现

热门文章

最新文章

AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等