开发者社区> 问答> 正文

tensorflow LSTM时间序列预测问题?报错

#coding=utf-8
import numpy as np
import tensorflow as tf
import matplotlib as mpl
mpl.use('Agg')
from matplotlib import pyplot as plt



learn=tf.contrib.learn

HIDDEN_SIZE=30    #LSTM中隐藏节点的个数
NUM_LAYERS=2    #LSTM层数
TIMESTEPS=10    #循环神经网络的截断长度
TRAINING_STEPS=10000    #训练轮数
BATCH_SIZE=32       #batch大小
TRAINING_EXAMPLES=10000      #训练数据个数
TESTING_EXAMPLES=1000    #测试数据个数
SAMPLE_GAP=0.01       #采样间隔

def generate_data(seq):
    X=[]
    y=[]
    for i in range(len(seq)-TIMESTEPS-1):
        X.append([seq[i:i+TIMESTEPS]])
        y.append([seq[i+TIMESTEPS]])
    return np.array(X,dtype=np.float32),np.array(y,dtype=np.float32)

def lstm_model(X,y):
    lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
    cell=tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*NUM_LAYERS)

    x_=tf.unstack(X,axis=1)

    output,_=tf.nn.dynamic_rnn(cell,x_,dtype=tf.float32)

    output=output[-1]
    prediction,loss=learn.models.linear_regression(output,y)
    train_op=tf.contrib.layers.optimize_loss(loss,tf.contrib.framework.get_global_step(),optimizer="Adagrad",learning_rate=0.1)
    return prediction,loss,train_op

regressor=learn.Estimator(model_fn=lstm_model)

test_start=TRAINING_EXAMPLES*SAMPLE_GAP
test_end=(TRAINING_EXAMPLES+TESTING_EXAMPLES)*SAMPLE_GAP
train_X,train_y=generate_data(np.sin(np.linspace(0,test_start,TRAINING_EXAMPLES,dtype=np.float32)))
test_X,test_y=generate_data(np.sin(np.linspace(test_start,test_end,TESTING_EXAMPLES,dtype=np.float32)))

regressor.fit(train_X,train_y,batch_size=BATCH_SIZE,steps=TRAINING_STEPS)

predicted=[[pred] for pred in regressor.predict(test_X)]

rmse=np.sqrt(((predicted-test_y)**2).mean(axis=0))
print('Mean square error is: %f'%rmse[0])

fig=plt.figure()
plot_predicted=plt.plot(predicted,label='predicted')
plot_test=plt.plot(test_y,label='real_sin')
plt.legend([plot_predicted,plot_test],['predicted','real_sin'])
fig.savefig('sin.png')

这是我的代码,预测正弦函数的深度学习算法。

提示报错ValueError: Shape (10, ?) must have rank at least 3

应该是output,_=tf.nn.dynamic_rnn(cell,x_,dtype=tf.float32)这一行开始出现了问题。

请教一下诸位大神,这个应该怎么解决

展开
收起
爱吃鱼的程序员 2020-06-08 13:26:46 764 0
1 条回答
写回答
取消 提交回答
  • https://developer.aliyun.com/profile/5yerqm5bn5yqg?spm=a2c6h.12873639.0.0.6eae304abcjaIB

    请问你解决这个问题了吗?我也报错相同的问题

    2020-06-08 13:27:04
    赞同 展开评论 打赏
问答排行榜
最热
最新

相关电子书

更多
使用TensorFlow搭建智能开发系统自动生成App UI 立即下载
从零到一:IOS平台TensorFlow入门及应用详解 立即下载
从零到一:IOS平台TensorFlow入门及应用详解(附源 立即下载