【阿旭机器学习实战】【31】股票价格预测案例--线性回归

简介: 【阿旭机器学习实战】【31】股票价格预测案例--线性回归

1. 读取数据

import numpy as np # 数学计算
import pandas as pd # 数据处理
import matplotlib.pyplot as plt
from datetime import datetime as dt

关注公众号:阿旭算法与机器学习,回复:“ML31”即可获取本文数据集、源码与项目文档,欢迎共同学习交流

df = pd.read_csv('./000001.csv') 
print(np.shape(df))
df.head()
(611, 14)
date open high close low volume price_change p_change ma5 ma10 ma20 v_ma5 v_ma10 v_ma20
0 2019-05-30 12.32 12.38 12.22 12.11 646284.62 -0.18 -1.45 12.366 12.390 12.579 747470.29 739308.42 953969.39
1 2019-05-29 12.36 12.59 12.40 12.26 666411.50 -0.09 -0.72 12.380 12.453 12.673 751584.45 738170.10 973189.95
2 2019-05-28 12.31 12.55 12.49 12.26 880703.12 0.12 0.97 12.380 12.505 12.742 719548.29 781927.80 990340.43
3 2019-05-27 12.21 12.42 12.37 11.93 1048426.00 0.02 0.16 12.394 12.505 12.824 689649.77 812117.30 1001879.10
4 2019-05-24 12.35 12.45 12.35 12.31 495526.19 0.06 0.49 12.396 12.498 12.928 637251.61 781466.47 1046943.98

股票数据的特征

  • date:日期
  • open:开盘价
  • high:最高价
  • close:收盘价
  • low:最低价
  • volume:成交量
  • price_change:价格变动
  • p_change:涨跌幅
  • ma5:5日均价
  • ma10:10日均价
  • ma20:20日均价
  • v_ma5:5日均量
  • v_ma10:10日均量
  • v_ma20:20日均量
# 将每一个数据的键值的类型从字符串转为日期
df['date'] = pd.to_datetime(df['date'])
# 将日期变为索引
df = df.set_index('date')
# 按照时间升序排列
df.sort_values(by=['date'], inplace=True, ascending=True)
df.tail()
open high close low volume price_change p_change ma5 ma10 ma20 v_ma5 v_ma10 v_ma20
date
2019-05-24 12.35 12.45 12.35 12.31 495526.19 0.06 0.49 12.396 12.498 12.928 637251.61 781466.47 1046943.98
2019-05-27 12.21 12.42 12.37 11.93 1048426.00 0.02 0.16 12.394 12.505 12.824 689649.77 812117.30 1001879.10
2019-05-28 12.31 12.55 12.49 12.26 880703.12 0.12 0.97 12.380 12.505 12.742 719548.29 781927.80 990340.43
2019-05-29 12.36 12.59 12.40 12.26 666411.50 -0.09 -0.72 12.380 12.453 12.673 751584.45 738170.10 973189.95
2019-05-30 12.32 12.38 12.22 12.11 646284.62 -0.18 -1.45 12.366 12.390 12.579 747470.29 739308.42 953969.39
# 检测是否有缺失数据 NaNs
df.dropna(axis=0 , inplace=True)
df.isna().sum()
open            0
high            0
close           0
low             0
volume          0
price_change    0
p_change        0
ma5             0
ma10            0
ma20            0
v_ma5           0
v_ma10          0
v_ma20          0
dtype: int64

K线图绘制

Min_date = df.index.min()
Max_date = df.index.max()
print ("First date is",Min_date)
print ("Last date is",Max_date)
print (Max_date - Min_date)
First date is 2016-11-29 00:00:00
Last date is 2019-05-30 00:00:00
912 days 00:00:00
from plotly import tools
from plotly.graph_objs import *
from plotly.offline import init_notebook_mode, iplot, iplot_mpl
init_notebook_mode()
import chart_studio.plotly as py
import plotly.graph_objs as go
trace = go.Ohlc(x=df.index, open=df['open'], high=df['high'], low=df['low'], close=df['close'])
data = [trace]
iplot(data, filename='simple_ohlc')

2.构建回归模型

from sklearn.linear_model import LinearRegression
from sklearn import preprocessing
# 创建标签数据:即预测值, 根据当前的数据预测5天以后的收盘价
num = 5 # 预测5天后的情况
df['label'] = df['close'].shift(-num) # 预测值,将5天后的收盘价当作当前样本的标签
                                     
print(df.shape)
(611, 14)
# 丢弃 'label', 'price_change', 'p_change', 不需要它们做预测
Data = df.drop(['label', 'price_change', 'p_change'],axis=1)
Data.tail()
open high close low volume ma5 ma10 ma20 v_ma5 v_ma10 v_ma20
date
2019-05-24 12.35 12.45 12.35 12.31 495526.19 12.396 12.498 12.928 637251.61 781466.47 1046943.98
2019-05-27 12.21 12.42 12.37 11.93 1048426.00 12.394 12.505 12.824 689649.77 812117.30 1001879.10
2019-05-28 12.31 12.55 12.49 12.26 880703.12 12.380 12.505 12.742 719548.29 781927.80 990340.43
2019-05-29 12.36 12.59 12.40 12.26 666411.50 12.380 12.453 12.673 751584.45 738170.10 973189.95
2019-05-30 12.32 12.38 12.22 12.11 646284.62 12.366 12.390 12.579 747470.29 739308.42 953969.39
X = Data.values
# 去掉最后5行,因为没有Y的值
X = X[:-num]
# 将特征进行归一化
X = preprocessing.scale(X)
# 去掉标签为null的最后5行
df.dropna(inplace=True)
Target = df.label
y = Target.values
print(np.shape(X), np.shape(y))
(606, 11) (606,)
# 将数据分为训练数据和测试数据
X_train, y_train = X[0:550, :], y[0:550]
X_test, y_test = X[550:, -51:], y[550:606]
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)
(550, 11)
(550,)
(56, 11)
(56,)
lr = LinearRegression()
lr.fit(X_train, y_train)
lr.score(X_test, y_test) # 使用绝对系数 R^2 评估模型
0.04930040648385525
# 做预测 :取最后5行数据,预测5天后的股票价格
X_Predict = X[-num:]
Forecast = lr.predict(X_Predict)
print(Forecast)
print(y[-num:])
[12.5019651  12.45069629 12.56248765 12.3172638  12.27070154]
[12.35 12.37 12.49 12.4  12.22]
• 1
• 2
# 查看模型的各个特征参数的系数值
for idx, col_name in enumerate(['open', 'high', 'close', 'low', 'volume', 'ma5', 'ma10', 'ma20', 'v_ma5', 'v_ma10', 'v_ma20']):
    print("The coefficient for {} is {}".format(col_name, lr.coef_[idx]))
The coefficient for open is -0.7623399996475224
The coefficient for high is 0.8321435171405448
The coefficient for close is 0.24463705375238926
The coefficient for low is 1.091415550493547
The coefficient for volume is 0.0043807937569128675
The coefficient for ma5 is -0.30717535019465575
The coefficient for ma10 is 0.1935431079947582
The coefficient for ma20 is 0.24902077484698157
The coefficient for v_ma5 is 0.17472336466033722
The coefficient for v_ma10 is 0.08873934447969857
The coefficient for v_ma20 is -0.27910702694420775

3.绘制预测结果

# 预测 2019-05-13 到 2019-05-17 , 一共 5 天的收盘价 
trange = pd.date_range('2019-05-13', periods=num, freq='d')
trange
DatetimeIndex(['2019-05-13', '2019-05-14', '2019-05-15', '2019-05-16',
               '2019-05-17'],
              dtype='datetime64[ns]', freq='D')
# 产生预测值dataframe
Predict_df = pd.DataFrame(Forecast, index=trange)
Predict_df.columns = ['forecast']
Predict_df
forecast
2019-05-13 12.501965
2019-05-14 12.450696
2019-05-15 12.562488
2019-05-16 12.317264
2019-05-17 12.270702
# 将预测值添加到原始dataframe
df = pd.read_csv('./000001.csv') 
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
# 按照时间升序排列
df.sort_values(by=['date'], inplace=True, ascending=True)
df_concat = pd.concat([df, Predict_df], axis=1)
df_concat = df_concat[df_concat.index.isin(Predict_df.index)]
df_concat.tail(num)
open high close low volume price_change p_change ma5 ma10 ma20 v_ma5 v_ma10 v_ma20 forecast
2019-05-13 12.33 12.54 12.30 12.23 741917.75 -0.38 -3.00 12.538 13.143 13.637 1107915.51 1191640.89 1211461.61 12.501965
2019-05-14 12.20 12.75 12.49 12.16 1182598.12 0.19 1.54 12.446 12.979 13.585 1129903.46 1198753.07 1237823.69 12.450696
2019-05-15 12.58 13.11 12.92 12.57 1103988.50 0.43 3.44 12.510 12.892 13.560 1155611.00 1208209.79 1254306.88 12.562488
2019-05-16 12.93 12.99 12.85 12.78 634901.44 -0.07 -0.54 12.648 12.767 13.518 971160.96 1168630.36 1209357.42 12.317264
2019-05-17 12.92 12.93 12.44 12.36 965000.88 -0.41 -3.19 12.600 12.626 13.411 925681.34 1153473.43 1138638.70 12.270702
# 画预测值和实际值
df_concat['close'].plot(color='green', linewidth=1)
df_concat['forecast'].plot(color='orange', linewidth=3)
plt.xlabel('Time')
plt.ylabel('Price')
plt.show()


相关文章
|
8月前
|
机器学习/深度学习 存储 运维
机器学习异常检测实战:用Isolation Forest快速构建无标签异常检测系统
本研究通过实验演示了异常标记如何逐步完善异常检测方案和主要分类模型在欺诈检测中的应用。实验结果表明,Isolation Forest作为一个强大的异常检测模型,无需显式建模正常模式即可有效工作,在处理未见风险事件方面具有显著优势。
665 46
|
11月前
|
机器学习/深度学习 数据可视化 TensorFlow
Python 高级编程与实战:深入理解数据科学与机器学习
本文深入探讨了Python在数据科学与机器学习中的应用,介绍了pandas、numpy、matplotlib等数据科学工具,以及scikit-learn、tensorflow、keras等机器学习库。通过实战项目,如数据可视化和鸢尾花数据集分类,帮助读者掌握这些技术。最后提供了进一步学习资源,助力提升Python编程技能。
|
11月前
|
机器学习/深度学习 人工智能 Java
Java机器学习实战:基于DJL框架的手写数字识别全解析
在人工智能蓬勃发展的今天,Python凭借丰富的生态库(如TensorFlow、PyTorch)成为AI开发的首选语言。但Java作为企业级应用的基石,其在生产环境部署、性能优化和工程化方面的优势不容忽视。DJL(Deep Java Library)的出现完美填补了Java在深度学习领域的空白,它提供了一套统一的API,允许开发者无缝对接主流深度学习框架,将AI模型高效部署到Java生态中。本文将通过手写数字识别的完整流程,深入解析DJL框架的核心机制与应用实践。
699 3
|
11月前
|
机器学习/深度学习 数据可视化 算法
Python 高级编程与实战:深入理解数据科学与机器学习
在前几篇文章中,我们探讨了 Python 的基础语法、面向对象编程、函数式编程、元编程、性能优化和调试技巧。本文将深入探讨 Python 在数据科学和机器学习中的应用,并通过实战项目帮助你掌握这些技术。
|
12月前
|
数据可视化 API 开发者
R1类模型推理能力评测手把手实战
R1类模型推理能力评测手把手实战
371 2
|
12月前
|
人工智能 自然语言处理 网络安全
基于阿里云 Milvus + DeepSeek + PAI LangStudio 的低成本高精度 RAG 实战
阿里云向量检索服务Milvus版是一款全托管向量检索引擎,并确保与开源Milvus的完全兼容性,支持无缝迁移。它在开源版本的基础上增强了可扩展性,能提供大规模AI向量数据的相似性检索服务。凭借其开箱即用的特性、灵活的扩展能力和全链路监控告警,Milvus云服务成为多样化AI应用场景的理想选择,包括多模态搜索、检索增强生成(RAG)、搜索推荐、内容风险识别等。您还可以利用开源的Attu工具进行可视化操作,进一步促进应用的快速开发和部署。
|
12月前
|
数据可视化 API 开发者
R1类模型推理能力评测手把手实战
随着DeepSeek-R1模型的广泛应用,越来越多的开发者开始尝试复现类似的模型,以提升其推理能力。
1020 2
|
12月前
|
数据可视化 API 开发者
R1类模型推理能力评测手把手实战
随着DeepSeek-R1模型的广泛应用,越来越多的开发者开始尝试复现类似的模型,以提升其推理能力。
784 3
|
机器学习/深度学习 传感器 运维
使用机器学习技术进行时间序列缺失数据填充:基础方法与入门案例
本文探讨了时间序列分析中数据缺失的问题,并通过实际案例展示了如何利用机器学习技术进行缺失值补充。文章构建了一个模拟的能源生产数据集,采用线性回归和决策树回归两种方法进行缺失值补充,并从统计特征、自相关性、趋势和季节性等多个维度进行了详细评估。结果显示,决策树方法在处理复杂非线性模式和保持数据局部特征方面表现更佳,而线性回归方法则适用于简单的线性趋势数据。文章最后总结了两种方法的优劣,并给出了实际应用建议。
800 7
使用机器学习技术进行时间序列缺失数据填充:基础方法与入门案例
|
机器学习/深度学习 人工智能 算法
探索机器学习:从线性回归到深度学习
本文将带领读者从基础的线性回归模型开始,逐步深入到复杂的深度学习网络。我们将通过代码示例,展示如何实现这些算法,并解释其背后的数学原理。无论你是初学者还是有经验的开发者,这篇文章都将为你提供有价值的见解和知识。让我们一起踏上这段激动人心的旅程吧!
245 3

热门文章

最新文章