tensorflow2.0 回归预测广告与销量之间的关系

简介: tensorflow2.0 回归预测广告与销量之间的关系
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline


# 读入数据



data = pd.read_csv('./Advertising.csv')
data

200 rows × 5 columns


### 数据可视化



plt.scatter(data.TV,data.sales) #TV跟sales之间的线性关系
<matplotlib.collections.PathCollection at 0x1db4bfca708>
plt.scatter(data.TV,data.sales) #TV跟sales之间的线性关系
<matplotlib.collections.PathCollection at 0x1db4bfca708>

plt.scatter(data.radio,data.sales) #radio跟sales之间的线性关系
<matplotlib.collections.PathCollection at 0x1db4c025c48>

plt.scatter(data.newspaper,data.sales) #newspaper跟sales之间的线性关系
<matplotlib.collections.PathCollection at 0x1db4c0a11c8>


# 数据获取



x = data.iloc[:,1:-1] #去掉第一行和最后一行
y = data.iloc[:,-1] #去掉最后一列


# 搭建网络



model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(200, input_shape=(3,), activation='relu'))
tf.keras.layers.Dense(100, activation='relu'),
tf.keras.layers.Dense(50, activation='relu'),
tf.keras.layers.Dense(25, activation='relu'),
(<tensorflow.python.keras.layers.core.Dense at 0x1db4bfd2d88>,)


# 设置优化模型、损失函数



model.compile(optimizer='adam', loss='mse')


# 模型训练,迭代次数为3000


model.fit(x,y,epochs=3000)



# 预测


test = data.iloc[:10, 1:-1]
model.predict(test)

array([[ 0.       , 20.463024 , 20.463192 , ..., 20.463135 ,  0.       ,
        20.463041 ],
       [ 0.       , 12.255327 , 12.255281 , ..., 12.255297 ,  0.       ,
        12.255323 ],
       [ 0.       , 12.220794 , 12.220801 , ..., 12.2207985,  0.       ,
        12.220795 ],
       ...,
       [ 0.       , 12.077147 , 12.077019 , ..., 12.077063 ,  0.       ,
        12.077131 ],
       [ 0.       ,  3.7235494,  3.7232306, ...,  3.7233384,  0.       ,
         3.7235146],
       [ 0.       , 12.560595 , 12.5604925, ..., 12.560528 ,  0.       ,
        12.560583 ]], dtype=float32)
# 对TV=3000, radio=4000, newspaper=8000的产品预测
newdata = np.array([[3000, 4000, 8000]])
newdata = tf.convert_to_tensor(newdata)
model.predict(newdata)

array([[  0.     , 879.5799 , 879.6151 , 879.55835, 879.5527 , 879.5728 ,
        879.5859 , 879.6359 , 879.5362 ,   0.     , 879.6159 , 879.60004,
        879.5276 , 879.5824 ,   0.     , 879.57623, 879.5726 , 879.52856,
        879.5319 , 879.5582 , 879.58704,   0.     , 879.57697, 879.5325 ,
          0.     , 879.59314,   0.     , 879.53687, 879.5954 , 879.5831 ,
        879.60297, 879.58563, 879.5935 , 879.64575, 879.58405, 879.5719 ,
        879.602  ,   0.     , 879.57135,   0.     ,   0.     , 879.5816 ,
          0.     , 879.5819 , 879.588  ,   0.     , 879.6433 , 879.5492 ,
        879.6192 , 879.54034,   0.     ,   0.     , 879.5796 ,   0.     ,
        879.61273, 879.58875, 879.6085 , 879.55133, 879.5839 , 879.58734,
        879.5705 , 879.5794 , 879.56946, 879.5359 , 879.58   , 879.558  ,
        879.5586 ,   0.     , 879.53674, 879.6124 , 879.5777 ,   0.     ,
        879.57715, 879.5843 , 879.54456, 879.61444, 879.6114 , 879.564  ,
        879.5577 , 879.6062 ,   0.     , 879.5439 , 879.579  , 879.58795,
          0.     ,   0.     , 879.59534, 879.59265, 879.5501 , 879.5469 ,
        879.63715, 879.588  , 879.58875, 879.61774, 879.5832 , 879.5349 ,
        879.6245 , 879.61707, 879.63776, 879.52814, 879.541  , 879.6059 ,
        879.5777 , 879.57605, 879.5842 , 879.60376, 879.57263, 879.58453,
        879.5369 , 879.526  , 879.56195, 879.5804 ,   0.     , 879.5381 ,
        879.6174 , 879.6216 ,   0.     , 879.58386, 879.5861 , 879.5596 ,
        879.5752 ,   0.     , 879.584  , 879.5409 ,   0.     , 879.5707 ,
        879.5675 , 879.62164, 879.60065, 879.57086, 879.5787 ,   0.     ,
        879.57446, 879.55896, 879.6028 , 879.5937 , 879.6015 , 879.565  ,
        879.57623, 879.6132 ,   0.     , 879.6237 ,   0.     , 879.59235,
        879.58936, 879.53656, 879.6577 , 879.5495 , 879.5707 , 879.6216 ,
        879.5622 ,   0.     , 879.6187 ,   0.     , 879.5673 , 879.5593 ,
        879.6082 , 879.5289 , 879.56604, 879.61633, 879.57294, 879.58105,
        879.5859 , 879.6015 ,   0.     , 879.5849 , 879.6053 , 879.59924,
        879.6134 ,   0.     , 879.6034 , 879.5475 , 879.57996,   0.     ,
        879.55774, 879.57294,   0.     , 879.57043, 879.62555, 879.6166 ,
        879.5759 , 879.5498 , 879.53174, 879.541  , 879.58746, 879.57214,
        879.5436 , 879.5625 , 879.5702 , 879.62506, 879.61346,   0.     ,
        879.5801 , 879.62103, 879.55475, 879.6133 , 879.6022 , 879.6031 ,
          0.     , 879.5836 ]], dtype=float32)

若有收获,就点个赞吧

相关文章
|
4月前
|
机器学习/深度学习 数据挖掘 Python
简单几步,教你使用scikit-learn做分类和回归预测
简单几步,教你使用scikit-learn做分类和回归预测
|
7月前
|
机器学习/深度学习 数据可视化 搜索推荐
PYTHON条件生存森林模型CONDITIONAL SURVIVAL FOREST分类预测客户流失交叉验证可视化|数据分享
PYTHON条件生存森林模型CONDITIONAL SURVIVAL FOREST分类预测客户流失交叉验证可视化|数据分享
|
7月前
|
机器学习/深度学习 数据可视化 算法
R语言独立成分分析fastICA、谱聚类、支持向量回归SVR模型预测商店销量时间序列可视化
R语言独立成分分析fastICA、谱聚类、支持向量回归SVR模型预测商店销量时间序列可视化
|
7月前
|
数据可视化
R语言Copula模型分析股票市场板块相关性结构
R语言Copula模型分析股票市场板块相关性结构
|
7月前
|
数据可视化 索引 Python
数据分享|Python用PyMC3贝叶斯模型平均BMA:采样、信息准则比较和预测可视化灵长类动物的乳汁成分数据
数据分享|Python用PyMC3贝叶斯模型平均BMA:采样、信息准则比较和预测可视化灵长类动物的乳汁成分数据
|
7月前
|
数据采集 机器学习/深度学习 供应链
python基于评论情感分析和回归、arima销量预测的购物网站选品
python基于评论情感分析和回归、arima销量预测的购物网站选品
|
7月前
|
机器学习/深度学习 数据可视化
数据分享|R语言用RFM、决策树模型顾客购书行为的数据预测
数据分享|R语言用RFM、决策树模型顾客购书行为的数据预测
|
7月前
|
数据可视化 数据挖掘
R语言混合线性模型、多层次模型、回归模型分析学生平均成绩GPA和可视化
R语言混合线性模型、多层次模型、回归模型分析学生平均成绩GPA和可视化
|
机器学习/深度学习 PyTorch 算法框架/工具
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
243 0
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
|
机器学习/深度学习 存储 自然语言处理
基于图卷积神经网络GCN的时间序列预测:图与递归结构相结合的库存品需求预测
基于图卷积神经网络GCN的时间序列预测:图与递归结构相结合的库存品需求预测
476 0
基于图卷积神经网络GCN的时间序列预测:图与递归结构相结合的库存品需求预测