视频教学:交通流量预测LSTM实战详细教学_哔哩哔哩_bilibili
结果展示:
完整代码:
import torch import numpy as np import pandas as pd import matplotlib.pyplot as plt from torch import nn from torch.autograd import Variable data_csv = pd.read_csv("上海中山公园地铁客流2015年数据.csv",encoding = 'gb2312') print(data_csv.head()) plt.figure(figsize=(12,4)) plt.plot(data_csv[0:300]) plt.show() # 创建训练和测试LSTM模型的数据集,是通过前面测试的15min时间粒度的客流量来预测当前时间粒度的客流量,我们令前2个时间粒度的客流数据是输入,对应代码中的step=2, # 把当前时间粒度的客流数据作为输出,划分数据