目录
输出结果
LSTM代码
def LSTM(batch, inputs=None):
    """Build the LSTM sub-graph and return its predictions and final state.

    Pipeline: a dense input projection (``weights['in']``, ``biases['in']``),
    a single ``BasicLSTMCell`` of ``rnn_unit`` units unrolled with
    ``tf.nn.dynamic_rnn``, then a dense output projection
    (``weights['out']``, ``biases['out']``).

    Args:
        batch: Batch size used to build the zero initial state
            (``cell.zero_state(batch, ...)``).
        inputs: Tensor fed to the input projection. Optional; when omitted
            the function falls back to the module-level name ``input``,
            exactly as the original code did. NOTE(review): presumably a
            2-D tensor of shape (batch * time_step, n_features) so that
            ``tf.matmul(inputs, weights['in'])`` is valid -- confirm against
            the caller.

    Returns:
        A ``(pred, final_states)`` tuple. ``pred`` is the output-projection
        result for every time step, flattened to one row per
        (sample, step). ``final_states`` is the final cell state returned by
        ``tf.nn.dynamic_rnn``.

    Relies on module-level ``tf``, ``weights``, ``biases``, ``time_step`` and
    ``rnn_unit`` defined elsewhere in the file.
    """
    if inputs is None:
        # Original behaviour: read the module-level ``input`` (which shadows
        # the builtin). If no such global exists this resolves to the builtin
        # ``input`` function and tf.matmul fails -- pass ``inputs`` explicitly.
        inputs = input

    w_in = weights['in']
    b_in = biases['in']
    input_rnn = tf.matmul(inputs, w_in) + b_in
    # dynamic_rnn (time_major=False by default) expects a 3-D
    # (batch, time, features) tensor, so fold the rows back into sequences.
    input_rnn = tf.reshape(input_rnn, [-1, time_step, rnn_unit])

    cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
    init_state = cell.zero_state(batch, dtype=tf.float32)
    # dtype is only required by dynamic_rnn when initial_state is omitted;
    # it is kept here for parity with the original call.
    output_rnn, final_states = tf.nn.dynamic_rnn(
        cell, input_rnn, initial_state=init_state, dtype=tf.float32)

    # Flatten (batch, time_step, rnn_unit) to 2-D rows for the output layer.
    output = tf.reshape(output_rnn, [-1, rnn_unit])
    w_out = weights['out']
    b_out = biases['out']
    pred = tf.matmul(output, w_out) + b_out
    return pred, final_states