#coding=utf-8 import numpy as np import tensorflow as tf import matplotlib as mpl mpl.use('Agg') from matplotlib import pyplot as plt learn=tf.contrib.learn HIDDEN_SIZE=30 #LSTM中隐藏节点的个数 NUM_LAYERS=2 #LSTM层数 TIMESTEPS=10 #循环神经网络的截断长度 TRAINING_STEPS=10000 #训练轮数 BATCH_SIZE=32 #batch大小 TRAINING_EXAMPLES=10000 #训练数据个数 TESTING_EXAMPLES=1000 #测试数据个数 SAMPLE_GAP=0.01 #采样间隔 def generate_data(seq): X=[] y=[] for i in range(len(seq)-TIMESTEPS-1): X.append([seq[i:i+TIMESTEPS]]) y.append([seq[i+TIMESTEPS]]) return np.array(X,dtype=np.float32),np.array(y,dtype=np.float32) def lstm_model(X,y): lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) cell=tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*NUM_LAYERS) x_=tf.unstack(X,axis=1) output,_=tf.nn.dynamic_rnn(cell,x_,dtype=tf.float32) output=output[-1] prediction,loss=learn.models.linear_regression(output,y) train_op=tf.contrib.layers.optimize_loss(loss,tf.contrib.framework.get_global_step(),optimizer="Adagrad",learning_rate=0.1) return prediction,loss,train_op regressor=learn.Estimator(model_fn=lstm_model) test_start=TRAINING_EXAMPLES*SAMPLE_GAP test_end=(TRAINING_EXAMPLES+TESTING_EXAMPLES)*SAMPLE_GAP train_X,train_y=generate_data(np.sin(np.linspace(0,test_start,TRAINING_EXAMPLES,dtype=np.float32))) test_X,test_y=generate_data(np.sin(np.linspace(test_start,test_end,TESTING_EXAMPLES,dtype=np.float32))) regressor.fit(train_X,train_y,batch_size=BATCH_SIZE,steps=TRAINING_STEPS) predicted=[[pred] for pred in regressor.predict(test_X)] rmse=np.sqrt(((predicted-test_y)**2).mean(axis=0)) print('Mean square error is: %f'%rmse[0]) fig=plt.figure() plot_predicted=plt.plot(predicted,label='predicted') plot_test=plt.plot(test_y,label='real_sin') plt.legend([plot_predicted,plot_test],['predicted','real_sin']) fig.savefig('sin.png')
这是我的代码,预测正弦函数的深度学习算法。
提示报错ValueError: Shape (10, ?) must have rank at least 3
应该是output,_=tf.nn.dynamic_rnn(cell,x_,dtype=tf.float32)这一行开始出现了问题。
请教一下诸位大神,这个应该怎么解决
请问你解决这个问题了吗?我也报错相同的问题
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。