使用SARIMAX进行时间序列预测。
#!/usr/bin/env python
#-*-coding:utf-8-*-
#******************************************************************************
#****************Description:Time Series prediction using SARIMAX
#****************Author:Duan Tingyin
#****************Date:2018.02.14
#**************************************************
import pandas as pd
import matplotlib.pyplot as plt
import datetime
from statsmodels.tsa.api import SARIMAX
# Directory containing the input CSV files, relative to the working directory.
datapath = '../data/'

# Load the training table first, then the two test tables in order.
train_df = pd.read_csv(datapath + '[new] yancheng_train_20171226.csv')
testA_df, testB_df = (
    pd.read_csv(datapath + csv_name)
    for csv_name in ('yancheng_testA_20171225.csv', 'yancheng_testB_20180224.csv')
)

# Total sale_quantity per (sale_date, class_id) pair, with the group keys
# restored as ordinary columns next to the summed value.
train_class = (
    train_df
    .groupby(['sale_date', 'class_id'])['sale_quantity']
    .sum()
    .reset_index()
)
# Preview of the first rows; the value is discarded when run as a script.
train_class.head()
def plt_class(data, x, y, class_id):
    """Scatter-plot column ``y`` against column ``x`` for a single class.

    Args:
        data: Table exposing a ``class_id`` column plus the columns named
            by ``x`` and ``y``.
        x: Name of the column used for the horizontal axis.
        y: Name of the column used for the vertical axis.
        class_id: ``class_id`` value selecting the rows to plot.

    The points are drawn through ``plt.scatter``; the function returns
    nothing and does not call ``plt.show``.
    """
    belongs_to_class = data.class_id == class_id
    rows_for_class = data[belongs_to_class]
    plt.scatter(x=rows_for_class[x], y=rows_for_class[y])
def trans_date(x):
    """Convert a ``YYYYMM`` value into the first day of that month.

    Args:
        x: Year-month written as four year digits followed by the month
            number, e.g. ``201711`` or ``"201711"``. Integral floats such
            as ``201711.0`` are accepted as well.

    Returns:
        datetime.date: The first day of the given month.

    Raises:
        ValueError: If the year or month part is not a valid integer, or
            the month is outside 1-12.
    """
    # A float like 201711.0 stringifies as "201711.0", which would make the
    # month part "11.0" and fail to parse; collapse integral floats first.
    if isinstance(x, float) and x.is_integer():
        x = int(x)
    str_x = str(x)
    year = int(str_x[:4])
    month = int(str_x[4:])
    return datetime.date(year, month, 1)
# Convert the YYYYMM columns of each table into first-of-month dates.
train_class['_sale_date'] = train_class['sale_date'].apply(trans_date)
testA_df['_sale_date'] = testA_df['predict_date'].apply(trans_date)
testB_df['_sale_date'] = testB_df['predict_date'].apply(trans_date)

# Lines of the prediction file, starting with its header.
rows = ["predict_date,class_id,predict_quantity"]
# Class ids whose model fit or forecast raised and used the fallback value.
ex = []

# The context manager closes the test file even if a line fails to parse.
with open("../data/yancheng_testA_20171225.csv", "r") as f:
    for line in f:
        # Skip the header row, and blank lines that have no class_id field.
        if "date" in line or not line.strip():
            continue
        class_id = int(line.split(",")[1])
        this_class_id = train_class[train_class['class_id'] == class_id][['_sale_date', 'sale_quantity']]
        if class_id == 653436:
            print(this_class_id._sale_date)
        # The series is indexed by a synthetic month-end range built from the
        # row count, not by the real `_sale_date` values.
        # NOTE(review): this presumes one row per consecutive month for each
        # class -- confirm against the training data.
        monthly_index = pd.date_range(end='2017-11', periods=len(this_class_id['_sale_date']), freq='M')
        indexed_this_class_id = this_class_id.set_index(monthly_index)
        # NOTE(review): debug output, assumed to run for every class --
        # confirm it was not meant only for class 653436.
        print(this_class_id['_sale_date'], monthly_index)
        res = 0
        try:
            # `verbose` is not a SARIMAX model option; optimizer output is
            # controlled by `disp` on fit().
            fit1 = SARIMAX(indexed_this_class_id.sale_quantity).fit(disp=False)
            pre = fit1.get_forecast().conf_int()
            # One-step forecast taken as the midpoint of the confidence interval.
            midpoint = (pre['lower sale_quantity'] + pre['upper sale_quantity']) * 0.5
            # Pull out the scalar explicitly instead of calling int() on a
            # one-element Series; a NaN midpoint raises and uses the fallback.
            res = int(round(float(midpoint.iloc[0])))
        except Exception as e:
            # Best-effort fallback: report the failure, remember the class,
            # and reuse its last observed quantity as the prediction.
            print(e)
            ex.append(class_id)
            plt_class(train_class, 'sale_date', 'sale_quantity', class_id)
            # With no training rows there is no last value; res stays 0.
            if not this_class_id.empty:
                res = int(this_class_id['sale_quantity'].iloc[-1])
            this_class_id.to_csv('EXCEPTION' + str(class_id) + ".csv", header=True, index=False, float_format='%.0f')
        rows.append("201711" + "," + str(class_id) + "," + str(res))

# Header and prediction rows, newline-separated, with a trailing newline.
s = "\n".join(rows) + "\n"

train_class[['sale_date', 'class_id', 'sale_quantity']].to_csv('train_class.csv', header=True, index=False, float_format='%.0f')
with open("201711.csv", "w") as fw:
    fw.write(s)
print(ex)