#coding=utf-8
import numpy as np
import tensorflow as tf
import matplotlib as mpl
mpl.use('Agg')
from matplotlib import pyplot as plt
learn=tf.contrib.learn
HIDDEN_SIZE=30 #LSTM中隐藏节点的个数
NUM_LAYERS=2 #LSTM层数
TIMESTEPS=10 #循环神经网络的截断长度
TRAINING_STEPS=10000 #训练轮数
BATCH_SIZE=32 #batch大小
TRAINING_EXAMPLES=10000 #训练数据个数
TESTING_EXAMPLES=1000 #测试数据个数
SAMPLE_GAP=0.01 #采样间隔
def generate_data(seq):
X=[]
y=[]
for i in range(len(seq)-TIMESTEPS-1):
X.append([seq[i:i+TIMESTEPS]])
y.append([seq[i+TIMESTEPS]])
return np.array(X,dtype=np.float32),np.array(y,dtype=np.float32)
def lstm_model(X,y):
lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
cell=tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*NUM_LAYERS)
x_=tf.unstack(X,axis=1)
output,_=tf.nn.dynamic_rnn(cell,x_,dtype=tf.float32)
output=output[-1]
prediction,loss=learn.models.linear_regression(output,y)
train_op=tf.contrib.layers.optimize_loss(loss,tf.contrib.framework.get_global_step(),optimizer="Adagrad",learning_rate=0.1)
return prediction,loss,train_op
regressor=learn.Estimator(model_fn=lstm_model)
test_start=TRAINING_EXAMPLES*SAMPLE_GAP
test_end=(TRAINING_EXAMPLES+TESTING_EXAMPLES)*SAMPLE_GAP
train_X,train_y=generate_data(np.sin(np.linspace(0,test_start,TRAINING_EXAMPLES,dtype=np.float32)))
test_X,test_y=generate_data(np.sin(np.linspace(test_start,test_end,TESTING_EXAMPLES,dtype=np.float32)))
regressor.fit(train_X,train_y,batch_size=BATCH_SIZE,steps=TRAINING_STEPS)
predicted=[[pred] for pred in regressor.predict(test_X)]
rmse=np.sqrt(((predicted-test_y)**2).mean(axis=0))
print('Mean square error is: %f'%rmse[0])
fig=plt.figure()
plot_predicted=plt.plot(predicted,label='predicted')
plot_test=plt.plot(test_y,label='real_sin')
plt.legend([plot_predicted,plot_test],['predicted','real_sin'])
fig.savefig('sin.png')
这是我的代码,预测正弦函数的深度学习算法。
提示报错ValueError: Shape (10, ?) must have rank at least 3
应该是output,_=tf.nn.dynamic_rnn(cell,x_,dtype=tf.float32)这一行开始出现了问题。
请教一下诸位大神,这个应该怎么解决
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。
请问你解决这个问题了吗?我也报错相同的问题