Output

(Animated Matplotlib figure: the target cos curve in red and the LSTM prediction in green dashes, converging over 200 training steps.)
Code design
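The script below is built around two pieces: get_batch(), which slices a running signal into batches of 50 sequences of 20 time steps each (sin as input, cos as target), and the LSTMRNN class, which stacks an input projection layer, a 10-unit BasicLSTMCell unrolled with dynamic_rnn, and an output projection layer. The cost is a squared-error variant of sequence_loss_by_example, minimized with Adam at LR = 0.006, and each training step feeds the final LSTM state back in as the next initial state, so the network reads the signal as one continuous stream.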
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

BATCH_START = 0
TIME_STEPS = 20
BATCH_SIZE = 50
INPUT_SIZE = 1
OUTPUT_SIZE = 1
CELL_SIZE = 10
LR = 0.006
BATCH_START_TEST = 0


def get_batch():
    global BATCH_START, TIME_STEPS
    # xs shape (50 batch, 20 steps)
    xs = np.arange(BATCH_START, BATCH_START + TIME_STEPS * BATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10 * np.pi)
    seq = np.sin(xs)
    res = np.cos(xs)
    BATCH_START += TIME_STEPS
    # shapes: seq (50, 20, 1), res (50, 20, 1), xs (50, 20)
    return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs]


class LSTMRNN(object):
    def __init__(self, n_steps, input_size, output_size, cell_size, batch_size):
        self.n_steps = n_steps
        self.input_size = input_size
        self.output_size = output_size
        self.cell_size = cell_size
        self.batch_size = batch_size
        with tf.name_scope('inputs'):
            self.xs = tf.placeholder(tf.float32, [None, n_steps, input_size], name='xs')
            self.ys = tf.placeholder(tf.float32, [None, n_steps, output_size], name='ys')
        with tf.variable_scope('in_hidden'):
            self.add_input_layer()
        with tf.variable_scope('LSTM_cell'):
            self.add_cell()
        with tf.variable_scope('out_hidden'):
            self.add_output_layer()
        with tf.name_scope('cost'):
            self.compute_cost()
        with tf.name_scope('train'):
            self.train_op = tf.train.AdamOptimizer(LR).minimize(self.cost)

    def add_input_layer(self):
        # (batch, n_steps, input_size) -> (batch * n_steps, input_size)
        l_in_x = tf.reshape(self.xs, [-1, self.input_size], name='2_2D')
        Ws_in = self._weight_variable([self.input_size, self.cell_size])
        bs_in = self._bias_variable([self.cell_size, ])
        with tf.name_scope('Wx_plus_b'):
            l_in_y = tf.matmul(l_in_x, Ws_in) + bs_in
        # back to (batch, n_steps, cell_size) for dynamic_rnn
        self.l_in_y = tf.reshape(l_in_y, [-1, self.n_steps, self.cell_size], name='2_3D')

    def add_cell(self):
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.cell_size, forget_bias=1.0, state_is_tuple=True)
        with tf.name_scope('initial_state'):
            self.cell_init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
        self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
            lstm_cell, self.l_in_y, initial_state=self.cell_init_state, time_major=False)

    def add_output_layer(self):
        # (batch, n_steps, cell_size) -> (batch * n_steps, cell_size)
        l_out_x = tf.reshape(self.cell_outputs, [-1, self.cell_size], name='2_2D')
        Ws_out = self._weight_variable([self.cell_size, self.output_size])
        bs_out = self._bias_variable([self.output_size, ])
        with tf.name_scope('Wx_plus_b'):
            self.pred = tf.matmul(l_out_x, Ws_out) + bs_out

    def compute_cost(self):
        # treat the whole flattened batch as one "time step" of the seq2seq loss
        losses = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
            [tf.reshape(self.pred, [-1], name='reshape_pred')],
            [tf.reshape(self.ys, [-1], name='reshape_target')],
            [tf.ones([self.batch_size * self.n_steps], dtype=tf.float32)],
            average_across_timesteps=True,
            softmax_loss_function=self.ms_error,
            name='losses'
        )
        with tf.name_scope('average_cost'):
            self.cost = tf.div(
                tf.reduce_sum(losses, name='losses_sum'),
                self.batch_size,
                name='average_cost')
            tf.summary.scalar('cost', self.cost)

    @staticmethod
    def ms_error(labels, logits):
        # TF >= 1.0 calls softmax_loss_function with labels=/logits= keyword
        # arguments, so the parameters must use these names; tf.sub() was also
        # removed in favor of tf.subtract()
        return tf.square(tf.subtract(labels, logits))

    def _weight_variable(self, shape, name='weights'):
        initializer = tf.random_normal_initializer(mean=0., stddev=1.)
        return tf.get_variable(shape=shape, initializer=initializer, name=name)

    def _bias_variable(self, shape, name='biases'):
        initializer = tf.constant_initializer(0.1)
        return tf.get_variable(name=name, shape=shape, initializer=initializer)


if __name__ == '__main__':
    model = LSTMRNN(TIME_STEPS, INPUT_SIZE, OUTPUT_SIZE, CELL_SIZE, BATCH_SIZE)
    sess = tf.Session()
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter("niu0127/logs0127", sess.graph)
    # tf.initialize_all_variables() is deprecated; use the current initializer
    sess.run(tf.global_variables_initializer())

    plt.ion()
    plt.show()

    for i in range(200):
        seq, res, xs = get_batch()
        if i == 0:
            # first step: start from the zero state created in add_cell()
            feed_dict = {
                model.xs: seq,
                model.ys: res,
            }
        else:
            # later steps: carry the final LSTM state over as the new initial state
            feed_dict = {
                model.xs: seq,
                model.ys: res,
                model.cell_init_state: state,
            }
        _, cost, state, pred = sess.run(
            [model.train_op, model.cost, model.cell_final_state, model.pred],
            feed_dict=feed_dict)

        # target in red, prediction for the first sequence in green dashes
        plt.plot(xs[0, :], res[0].flatten(), 'r', xs[0, :], pred.flatten()[:TIME_STEPS], 'g--')
        plt.title('Matplotlib,RNN,Efficient learning,Approach,Cosx --Jason Niu')
        plt.ylim((-1.2, 1.2))
        plt.draw()
        plt.pause(0.1)

        # write the cost summary every 20 steps so the FileWriter output is populated
        if i % 20 == 0:
            print('cost:', round(cost, 4))
            writer.add_summary(sess.run(merged, feed_dict), i)
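To sanity-check the batch shapes before wiring up the graph, a minimal NumPy-only sketch of the same slicing as get_batch() (using only the constants defined above) prints the shapes the placeholders expect:

import numpy as np

BATCH_START, TIME_STEPS, BATCH_SIZE = 0, 20, 50

# same slicing as get_batch(): one long ramp, reshaped to (batch, steps)
xs = np.arange(BATCH_START, BATCH_START + TIME_STEPS * BATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10 * np.pi)
seq = np.sin(xs)[:, :, np.newaxis]
res = np.cos(xs)[:, :, np.newaxis]

print(seq.shape, res.shape, xs.shape)  # (50, 20, 1) (50, 20, 1) (50, 20)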
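Because every weight passed to sequence_loss_by_example is 1 and ms_error returns plain squared differences, compute_cost() collapses to a sum of squared errors divided by the batch size, i.e. n_steps times the usual mean squared error. A self-contained NumPy sketch of that arithmetic, with made-up stand-in values:

import numpy as np

batch_size, n_steps = 50, 20
rng = np.random.RandomState(0)
pred = rng.randn(batch_size * n_steps)  # stand-in for the flattened model.pred
ys = rng.randn(batch_size * n_steps)    # stand-in for the flattened targets

losses = np.square(pred - ys)           # what ms_error + unit weights produce
cost = losses.sum() / batch_size        # what compute_cost() returns
assert np.isclose(cost, n_steps * np.mean(losses))  # = n_steps * MSE
print(cost)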
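After the loop, the same session can be reused for pure inference. A hypothetical extra step (not in the original post) that runs only the prediction, carrying over the state left by training, might look like:

# assumption: appended at the end of the training script, in the same session
seq, res, xs = get_batch()
pred = sess.run(model.pred, feed_dict={
    model.xs: seq,
    model.ys: res,
    model.cell_init_state: state,
})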
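The FileWriter logs the graph and the scalar cost summary to niu0127/logs0127. Assuming TensorBoard is installed alongside TensorFlow, running tensorboard --logdir=niu0127/logs0127 and opening the reported URL should show the cost curve and the graph grouped by the scopes defined above (inputs, in_hidden, LSTM_cell, out_hidden, cost, train).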