数据分析入门系列教程-股票走势预测分析

简介: 数据分析入门系列教程-股票走势预测分析

今天我们做一个关于股票的小项目--预测股票走势。首先要声明下,股市有风险,购买需谨慎啊!股票作为金融体系的一员,其走势收到了多方面的影响,并不是能够通过一两个算法,一些参数就可以完美预测,这是基于此,才衍生出了进入量化这个学科,专门用来做金融方面的数据分析。

而我们今天要做的小项目其实是最为初级的实战,就是通过线性回归算法,来简单预测未来股票的走势情况。

线性回归,这个最为基础的入门机器学习算法,它实在是太普遍,太简单了,甚至我们很多人在接触这个算法的时候,根本没有想到它竟然还是机器学习的一门算法。那么今天我们就简单的介绍下这个算法,并把它应用到股票预测上来。

线性回归

线性回归,英文是 liner regression。是利用线性回归方程的最小二乘函数对一个或多个自变量和因变量之间的关系进行建模的方法。


在上面的图表中可以看出,每一个 X 值,都会对应一个或几个 Y 值,而这些 Y 值并不是完全无规律的,它们会大致分布在一条直线的两边,对于这样的数据点,我们就可以使用线性回归来拟合预测。

数学表达

例如在空间中,我们拥有如下的数据点


其中 Y 是我们需要预测的数值,而 X 则是这些样本点的多个特征。

又因为我们假设 Y 和 X 之间是有线性关系的,所以就可以得到一个线性方程


矩阵表达


对于线性回归算法,就不再过多介绍了,感兴趣的同学可以去查找下资料,看看它的损失函数是如何推导的。

股票预测

现在我们就通过线性回归算法,来进行股票的预测。

获取数据

首先我们先导入相关库

import pandas as pd 
import tushare as ts

接下来我们就可以通过 tushare 库来获取股票数据了

import tushare as ts
df = ts.get_hist_data('000001')
print(df)
>>>
             open   high  close    low      volume  price_change  p_change  \
date                                                                         
2019-11-05  16.80  17.44  17.15  16.79  1172485.88          0.23      1.36   
2019-11-04  16.98  17.25  16.92  16.77   889824.25          0.06      0.36   
2019-11-01  16.35  17.00  16.86  16.28  1254655.50          0.60      3.69   
2019-10-31  16.42  16.47  16.26  16.24   862569.25         -0.17     -1.03   
2019-10-30  16.80  16.96  16.43  16.42   854317.19         -0.48     -2.84   
...           ...    ...    ...    ...         ...           ...       ...   
2017-05-12   8.68   8.90   8.90   8.64   917968.19          0.20      2.30   
2017-05-11   8.65   8.72   8.70   8.60   503643.56          0.03      0.35   
2017-05-10   8.63   8.80   8.67   8.62   573077.69          0.03      0.35   
2017-05-09   8.56   8.64   8.64   8.55   324194.47          0.07      0.82   
2017-05-08   8.60   8.62   8.57   8.54   460089.88         -0.06     -0.69                  ma5    ma10    ma20       v_ma5      v_ma10      v_ma20  
date                                                                    
2019-11-05  16.724  16.739  16.720  1006770.41   966103.30  1106131.26  
2019-11-04  16.676  16.666  16.673   930046.19  1037207.85  1116840.78  
2019-11-01  16.624  16.663  16.606   966107.56  1042820.65  1124667.73  
2019-10-31  16.628  16.628  16.558   853463.61  1037035.01  1112320.27  
2019-10-30  16.750  16.672  16.531   902792.71  1063396.31  1153521.44  
...            ...     ...     ...         ...         ...         ...  
2017-05-12   8.696   8.696   8.696   555794.76   555794.76   555794.76  
2017-05-11   8.645   8.645   8.645   465251.40   465251.40   465251.40  
2017-05-10   8.627   8.627   8.627   452454.01   452454.01   452454.01  
2017-05-09   8.605   8.605   8.605   392142.18   392142.18   392142.18  
2017-05-08   8.570   8.570   8.570   460089.88   460089.88   460089.88  [611 rows x 13 columns]

可以看到,此时的变量 df 中已经保存了 000001 股票的数据,时间范围是从2017.05 到2019.11

查看特征

接下来我们看下数据各个列的含义

股票数据的特征

  • date:日期
  • open:开盘价
  • high:最高价
  • close:收盘价
  • low:最低价
  • volume:成交量
  • price_change:价格变动
  • p_change:涨跌幅
  • ma5:5日均价
  • ma10:10日均价
  • ma20:20日均价
  • v_ma5:5日均量
  • v_ma10:10日均量
  • v_ma20:20日均量

数据处理

检查是否有缺失值

df.dropna(axis=0 , inplace=True)
df.isna().sum()
>>>
open            0
high            0
close           0
low             0
volume          0
price_change    0
p_change        0
ma5             0
ma10            0
ma20            0
v_ma5           0
v_ma10          0
v_ma20          0
dtype: int64

再把数据按照时间排序

df.sort_values(by=['date'], inplace=True, ascending=True)
df.tail()
>>>
             open   high  close    low      volume  price_change  p_change  \
date                                                                         
2019-10-30  16.80  16.96  16.43  16.42   854317.19         -0.48     -2.84   
2019-10-31  16.42  16.47  16.26  16.24   862569.25         -0.17     -1.03   
2019-11-01  16.35  17.00  16.86  16.28  1254655.50          0.60      3.69   
2019-11-04  16.98  17.25  16.92  16.77   889824.25          0.06      0.36   
2019-11-05  16.80  17.44  17.15  16.79  1172485.88          0.23      1.36                  ma5    ma10    ma20       v_ma5      v_ma10      v_ma20  
date                                                                    
2019-10-30  16.750  16.672  16.531   902792.71  1063396.31  1153521.44  
2019-10-31  16.628  16.628  16.558   853463.61  1037035.01  1112320.27  
2019-11-01  16.624  16.663  16.606   966107.56  1042820.65  1124667.73  
2019-11-04  16.676  16.666  16.673   930046.19  1037207.85  1116840.78  
2019-11-05  16.724  16.739  16.720  1006770.41   966103.30  1106131.26

画K线图

下面我们手动画一个K线图,来整体看下该支股票的走势

from plotly import tools
from plotly.graph_objs import *
from plotly.offline import init_notebook_mode, iplot, iplot_mpl
init_notebook_mode()
import plotly.plotly as py
import plotly.graph_objs as gotrace = go.Ohlc(x=df.index, open=df['open'], high=df['high'], low=df['low'], close=df['close'])
data = [trace]
iplot(data, filename='simple_ohlc')


我们来简单看下 K 线图该如何查看,先把图片放大,使我们可以看到每一天的具体情况


我们就以红框中的两天为例

绿色的代表当天使涨的,红色的代表当天使跌的;

而竖线的上下两端,就代表当天的最高价和最低价;

还有向左和向右的两条横线,向左的代表当天的开盘价,向右的代表当天的收盘价。所以绿色的线,向左的横线是低于向右的横线的,代表当天涨,红线则正好相反;

我们可以从K线图中看出,股票在总体上是没有线性规律的,但是在某几天之内,还是会有大致的线性规律的,所以我们可以通过线性回归预测未来某几天的股票走势,而不是未来某几个月甚至某几年的走势。

整理预测值

由于我们需要对训练好的模型做预测来确认模型的性能,所以需要把数据中的 close 数据做位置转移。

什么意思呢,比如说我要预测未来5天的股票走势,那么如果要预测的时间是2019-11-05,2019-11-04,2019-11-01,2019-10-31,2019-10-30,则需要在2019-10-29预测2019-11-05的股票走势,2019-10-29预测2019-11-04的走势,依次类推。所以就需要在2019-10-29、2019-10-28增加一列,可以记为 label,其数值就是2019-11-05、2019-11-04的 close 值。

num = 5 # 预测5天后的情况
df['label'] = df['close'].shift(-num) # 预测值

这里可能有点绕,没有理解的再慢慢体会下。

df.head(20)


观察下数据,每一行的 label 确实是未来第五天的 close 值。

提取特征

因为价格变化和交易量都是可以通过数据中的其他值计算出来的,所以和 label 一起排除在训练特征之外

feature = df.drop(['label', 'price_change', 'p_change'],axis=1)
print(feature.head())
>>>
            open  high  close   low     volume    ma5   ma10   ma20  \
date                                                                  
2017-05-08  8.60  8.62   8.57  8.54  460089.88  8.570  8.570  8.570   
2017-05-09  8.56  8.64   8.64  8.55  324194.47  8.605  8.605  8.605   
2017-05-10  8.63  8.80   8.67  8.62  573077.69  8.627  8.627  8.627   
2017-05-11  8.65  8.72   8.70  8.60  503643.56  8.645  8.645  8.645   
2017-05-12  8.68  8.90   8.90  8.64  917968.19  8.696  8.696  8.696                   v_ma5     v_ma10     v_ma20  
date                                         
2017-05-08  460089.88  460089.88  460089.88  
2017-05-09  392142.18  392142.18  392142.18  
2017-05-10  452454.01  452454.01  452454.01  
2017-05-11  465251.40  465251.40  465251.40  
2017-05-12  555794.76  555794.76  555794.76

数据规范化

from sklearn.linear_model import LinearRegression
from sklearn import preprocessing
X = feature.values
X = preprocessing.scale(X)
X = X[:-num]df.dropna(inplace=True)
Target = df.label
y = Target.valuesprint(np.shape(X), np.shape(y))
>>>
(606, 11) (606,)

构建模型

划分训练集和测试集

# 将数据分为训练数据和测试数据
X_train, y_train = X[0:550, :], y[0:550]
X_test, y_test = X[550:, -51:], y[550:606]
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)
>>>
(550, 11)
(550,)
(56, 11)
(56,)

训练模型

lr = LinearRegression()
lr.fit(X_train, y_train)
lr.score(X_test, y_test)
>>>
0.5549047367028686

可以看到模型的得分并不是很高,这样很正常,毕竟股票走势不可能简单的通过一个线性回归模型就能够准确预测的

预测

X_Predict = X[-num:]
Forecast = lr.predict(X_Predict)
print(Forecast)
print(y[-num:])
print(X_Predict)
>>>
[16.27972269 16.59639261 16.6260728  16.45388175 16.64749987]
[16.43 16.26 16.86 16.92 17.15]
[[ 2.32368696  2.35103674  2.37083859  2.4030061  -0.25218353  2.49196006
   2.61527982  2.44366571  0.22559643  0.22221994  0.69169444]
 [ 2.41461999  2.51367369  2.58229797  2.51131585 -0.01723373  2.50932124
   2.64797108  2.49788649  0.218406    0.27564109  0.59120658]
 [ 2.55607137  2.49396012  2.58733272  2.58352236 -0.7100412   2.54710735
   2.65160345  2.53922312  0.00424489  0.10942186  0.33793334]
 [ 2.65710807  2.53338726  2.47656828  2.53194628 -0.08204654  2.52361869
   2.62254455  2.57358085  0.05686897 -0.0040717   0.29472561]
 [ 2.51060486  2.47917494  2.60243696  2.57320714 -0.54847362  2.57365975
   2.60853401  2.62028587 -0.40702501 -0.0257436   0.25059866]]

下面我们就可以把真实的股票走势和我们预测的股票走势都通过折线的方式画出来,这样会更加直观的看出预测结果

我们先确定要预测的5天范围

trange = pd.date_range('2019-05-13', periods=num, freq='d')
Predict_df = pd.DataFrame(Forecast, index=trange)
Predict_df.columns = ['forecast']

再把预测的5天的数值插入到原始数据当中

# 将预测值添加到原始dataframe
df_new = ts.get_hist_data('000001')
# 按照时间升序排列
df_new.sort_values(by=['date'], inplace=True, ascending=True)
df_new.index = df_new.index.astype('datetime64[ns]')
df_concat = pd.concat([df_new, Predict_df], axis=1)df_concat = df_concat[df_concat.index.isin(Predict_df.index)]
df_concat


最后画出两条折线

# 画预测值和实际值
df_concat['close'].plot(color='green', linewidth=1)
df_concat['forecast'].plot(color='orange', linewidth=3)
plt.xlabel('Time')
plt.ylabel('Price')
plt.show()


可以看出,预测的股票数值与原始真实的数值是有较大差距的,但是股票的走势还是大致相近的,那么这也可以在一定程度上指导我们是买进还是卖出了!

完整代码,在 GitHub 上下载

https://github.com/zhouwei713/DataAnalyse/tree/master/stock_prediction

总结

本节我们简单介绍了线性回归算法,作为我们从小就接触过的算法,在机器学习领域也是有着不错的应用前景的。

我们就通过线性回归模型,训练了一个简单的股票预测程序,虽然准确率不高,但是总体走势还是有一定的参考性的。当然,还是那句话,股票走势是一个很复杂的事物,其会受到方方面面各种因素的影响,要想做好金融量化,股票预测等事情,需要付出更多的努力有技术!

相关文章
|
5天前
|
机器学习/深度学习 数据采集 数据可视化
Python数据分析入门:基础知识与必备工具
【4月更文挑战第12天】Python是大数据时代数据分析的热门语言,以其简单易学和丰富库资源备受青睐。本文介绍了Python数据分析基础,包括Python语言特点、数据分析概念及其优势。重点讲解了NumPy、Pandas、Matplotlib、Seaborn和Scikit-learn等必备工具,它们分别用于数值计算、数据处理、可视化和机器学习。此外,还概述了数据分析基本流程,从数据获取到结果展示。掌握这些知识和工具,有助于初学者快速入门Python数据分析。
|
5天前
|
数据采集 存储 数据可视化
Python数据分析从入门到实践
Python数据分析从入门到实践
|
5天前
|
机器学习/深度学习 数据可视化 数据挖掘
利用Python进行数据分析与可视化:从入门到精通
本文将介绍如何使用Python语言进行数据分析与可视化,从基础概念到高级技巧一应俱全。通过学习本文,读者将掌握Python在数据处理、分析和可视化方面的核心技能,为实际项目应用打下坚实基础。
|
5天前
|
机器学习/深度学习 数据可视化 数据挖掘
Python数据分析:从入门到实践
Python数据分析:从入门到实践
|
5天前
|
存储 数据挖掘 索引
Python 教程之 Pandas(14)—— 使用 Pandas 进行数据分析
Python 教程之 Pandas(14)—— 使用 Pandas 进行数据分析
28 0
Python 教程之 Pandas(14)—— 使用 Pandas 进行数据分析
|
5天前
|
存储 机器学习/深度学习 数据挖掘
提升数据分析效率:Amazon S3 Express One Zone数据湖实战教程
提升数据分析效率:Amazon S3 Express One Zone数据湖实战教程
91 1
|
5天前
|
存储 数据挖掘 Python
借助 PyPDF2 库把数据分析系列教程文章制作成了PDF电子书,欢迎来领取!
借助 PyPDF2 库把数据分析系列教程文章制作成了PDF电子书,欢迎来领取!
|
5天前
|
数据采集 算法 数据可视化
数据分析入门系列教程-EM实战-划分LOL英雄
数据分析入门系列教程-EM实战-划分LOL英雄
|
3天前
|
机器学习/深度学习 数据挖掘 Python
Python数据分析 | 泰坦尼克逻辑回归(下)
Python数据分析 | 泰坦尼克逻辑回归
7 1
|
3天前
|
机器学习/深度学习 数据挖掘 BI
Python数据分析 | 泰坦尼克逻辑回归(上)
Python数据分析 | 泰坦尼克逻辑回归
15 0