tensorflow2.0 回归预测广告与销量之间的关系

简介: tensorflow2.0 回归预测广告与销量之间的关系
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline


# 读入数据



data = pd.read_csv('./Advertising.csv')
data

200 rows × 5 columns


### 数据可视化



plt.scatter(data.TV,data.sales) #TV跟sales之间的线性关系
<matplotlib.collections.PathCollection at 0x1db4bfca708>
plt.scatter(data.TV,data.sales) #TV跟sales之间的线性关系
<matplotlib.collections.PathCollection at 0x1db4bfca708>

plt.scatter(data.radio,data.sales) #radio跟sales之间的线性关系
<matplotlib.collections.PathCollection at 0x1db4c025c48>

plt.scatter(data.newspaper,data.sales) #newspaper跟sales之间的线性关系
<matplotlib.collections.PathCollection at 0x1db4c0a11c8>


# 数据获取



x = data.iloc[:,1:-1] #去掉第一行和最后一行
y = data.iloc[:,-1] #去掉最后一列


# 搭建网络



model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(200, input_shape=(3,), activation='relu'))
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(50, activation='relu'),
tf.keras.layers.Dense(25, activation='relu'),
(<tensorflow.python.keras.layers.core.Dense at 0x1db4bfd2d88>,)


# 设置优化模型、损失函数



model.compile(optimizer='adam', loss='mse')


# 模型训练,迭代次数为3000


model.fit(x,y,epochs=3000)



# 预测


test = data.iloc[:10, 1:-1]
model.predict(test)

array([[ 0.       , 20.463024 , 20.463192 , ..., 20.463135 ,  0.       ,
        20.463041 ],
       [ 0.       , 12.255327 , 12.255281 , ..., 12.255297 ,  0.       ,
        12.255323 ],
       [ 0.       , 12.220794 , 12.220801 , ..., 12.2207985,  0.       ,
        12.220795 ],
       ...,
       [ 0.       , 12.077147 , 12.077019 , ..., 12.077063 ,  0.       ,
        12.077131 ],
       [ 0.       ,  3.7235494,  3.7232306, ...,  3.7233384,  0.       ,
         3.7235146],
       [ 0.       , 12.560595 , 12.5604925, ..., 12.560528 ,  0.       ,
        12.560583 ]], dtype=float32)
# 对TV=3000, radio=4000, newspaper=8000的产品预测
newdata = np.array([[3000, 4000, 8000]])
newdata = tf.convert_to_tensor(newdata)
model.predict(newdata)

array([[  0.     , 879.5799 , 879.6151 , 879.55835, 879.5527 , 879.5728 ,
        879.5859 , 879.6359 , 879.5362 ,   0.     , 879.6159 , 879.60004,
        879.5276 , 879.5824 ,   0.     , 879.57623, 879.5726 , 879.52856,
        879.5319 , 879.5582 , 879.58704,   0.     , 879.57697, 879.5325 ,
          0.     , 879.59314,   0.     , 879.53687, 879.5954 , 879.5831 ,
        879.60297, 879.58563, 879.5935 , 879.64575, 879.58405, 879.5719 ,
        879.602  ,   0.     , 879.57135,   0.     ,   0.     , 879.5816 ,
          0.     , 879.5819 , 879.588  ,   0.     , 879.6433 , 879.5492 ,
        879.6192 , 879.54034,   0.     ,   0.     , 879.5796 ,   0.     ,
        879.61273, 879.58875, 879.6085 , 879.55133, 879.5839 , 879.58734,
        879.5705 , 879.5794 , 879.56946, 879.5359 , 879.58   , 879.558  ,
        879.5586 ,   0.     , 879.53674, 879.6124 , 879.5777 ,   0.     ,
        879.57715, 879.5843 , 879.54456, 879.61444, 879.6114 , 879.564  ,
        879.5577 , 879.6062 ,   0.     , 879.5439 , 879.579  , 879.58795,
          0.     ,   0.     , 879.59534, 879.59265, 879.5501 , 879.5469 ,
        879.63715, 879.588  , 879.58875, 879.61774, 879.5832 , 879.5349 ,
        879.6245 , 879.61707, 879.63776, 879.52814, 879.541  , 879.6059 ,
        879.5777 , 879.57605, 879.5842 , 879.60376, 879.57263, 879.58453,
        879.5369 , 879.526  , 879.56195, 879.5804 ,   0.     , 879.5381 ,
        879.6174 , 879.6216 ,   0.     , 879.58386, 879.5861 , 879.5596 ,
        879.5752 ,   0.     , 879.584  , 879.5409 ,   0.     , 879.5707 ,
        879.5675 , 879.62164, 879.60065, 879.57086, 879.5787 ,   0.     ,
        879.57446, 879.55896, 879.6028 , 879.5937 , 879.6015 , 879.565  ,
        879.57623, 879.6132 ,   0.     , 879.6237 ,   0.     , 879.59235,
        879.58936, 879.53656, 879.6577 , 879.5495 , 879.5707 , 879.6216 ,
        879.5622 ,   0.     , 879.6187 ,   0.     , 879.5673 , 879.5593 ,
        879.6082 , 879.5289 , 879.56604, 879.61633, 879.57294, 879.58105,
        879.5859 , 879.6015 ,   0.     , 879.5849 , 879.6053 , 879.59924,
        879.6134 ,   0.     , 879.6034 , 879.5475 , 879.57996,   0.     ,
        879.55774, 879.57294,   0.     , 879.57043, 879.62555, 879.6166 ,
        879.5759 , 879.5498 , 879.53174, 879.541  , 879.58746, 879.57214,
        879.5436 , 879.5625 , 879.5702 , 879.62506, 879.61346,   0.     ,
        879.5801 , 879.62103, 879.55475, 879.6133 , 879.6022 , 879.6031 ,
          0.     , 879.5836 ]], dtype=float32)

若有收获,就点个赞吧

相关文章
|
机器学习/深度学习 供应链 安全
TSMixer:谷歌发布的用于时间序列预测的全新全mlp架构
这是谷歌在9月最近发布的一种新的架构 TSMixer: An all-MLP architecture for time series forecasting ,TSMixer是一种先进的多元模型,利用线性模型特征,在长期预测基准上表现良好。据我们所知,TSMixer是第一个在长期预测基准上表现与最先进的单变量模型一样好的多变量模型,在长期预测基准上,表明交叉变量信息不太有益。”
302 1
|
4月前
|
机器学习/深度学习 存储 数据可视化
谷歌的时间序列预测的基础模型TimesFM详解和对比测试
在本文中,我们将介绍模型架构、训练,并进行实际预测案例研究。将对TimesFM的预测能力进行分析,并将该模型与统计和机器学习模型进行对比。
162 2
|
6月前
|
机器学习/深度学习 数据采集 人工智能
SPSS modeler利用类神经网络对茅台股价涨跌幅度进行预测
SPSS modeler利用类神经网络对茅台股价涨跌幅度进行预测
|
6月前
|
机器学习/深度学习 数据可视化 搜索推荐
PYTHON条件生存森林模型CONDITIONAL SURVIVAL FOREST分类预测客户流失交叉验证可视化|数据分享
PYTHON条件生存森林模型CONDITIONAL SURVIVAL FOREST分类预测客户流失交叉验证可视化|数据分享
|
6月前
|
机器学习/深度学习 数据可视化 算法
R语言独立成分分析fastICA、谱聚类、支持向量回归SVR模型预测商店销量时间序列可视化
R语言独立成分分析fastICA、谱聚类、支持向量回归SVR模型预测商店销量时间序列可视化
|
6月前
|
自然语言处理 JavaScript 数据可视化
数据代码分享|R语言基于逐步多元回归模型的天猫商品流行度预测
数据代码分享|R语言基于逐步多元回归模型的天猫商品流行度预测
|
6月前
|
安全 vr&ar
R语言非线性动态回归模型ARIMAX、随机、确定性趋势时间序列预测个人消费和收入、用电量、国际游客数量
R语言非线性动态回归模型ARIMAX、随机、确定性趋势时间序列预测个人消费和收入、用电量、国际游客数量
|
6月前
|
数据采集 机器学习/深度学习 供应链
python基于评论情感分析和回归、arima销量预测的购物网站选品
python基于评论情感分析和回归、arima销量预测的购物网站选品
|
6月前
|
算法 数据挖掘 关系型数据库
有限混合模型聚类FMM、广义线性回归模型GLM混合应用分析威士忌市场和研究专利申请数据
有限混合模型聚类FMM、广义线性回归模型GLM混合应用分析威士忌市场和研究专利申请数据
|
机器学习/深度学习 PyTorch 算法框架/工具
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
226 0
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战