import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# %matplotlib inline  # notebook-only magic; kept as a comment so the script parses

# Load the advertising dataset.
# Expected layout: column 0 is a row index, then TV / radio / newspaper spend,
# and the last column is sales (200 rows x 5 columns). TODO confirm against the CSV.
data = pd.read_csv('./Advertising.csv')

# Visualize the relationship between each ad channel and sales.
plt.scatter(data.TV, data.sales)         # TV spend vs. sales
plt.scatter(data.radio, data.sales)      # radio spend vs. sales
plt.scatter(data.newspaper, data.sales)  # newspaper spend vs. sales
# Feature/target split.
# NOTE: the original comment claimed this drops the first and last *rows*;
# iloc[:, 1:-1] actually selects all rows and drops the first column (the CSV
# row index) and the last column (the target).
x = data.iloc[:, 1:-1]   # features: TV, radio, newspaper
y = data.iloc[:, -1]     # target: sales

# Build the network.
# BUG FIX: the original only `model.add`-ed the first Dense layer; the
# Dense(100)/Dense(50)/Dense(25) layers were constructed in a bare tuple and
# never attached, and there was no 1-unit output layer — which is why
# predict() returned a 25-wide vector per sample instead of a scalar.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(200, input_shape=(3,), activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(50, activation='relu'),
    tf.keras.layers.Dense(25, activation='relu'),
    tf.keras.layers.Dense(1),  # linear output for regression
])

# Adam optimizer with mean-squared-error loss (standard regression setup).
model.compile(optimizer='adam', loss='mse')

# Train for 3000 epochs.
model.fit(x, y, epochs=3000)

# Sanity check: predict on the first 10 training rows.
test = data.iloc[:10, 1:-1]
print(model.predict(test))

# Predict sales for a new product with TV=3000, radio=4000, newspaper=8000.
# predict() accepts a NumPy array directly; no tf.convert_to_tensor needed.
newdata = np.array([[3000, 4000, 8000]])
print(model.predict(newdata))
# 若有收获,就点个赞吧 (If you found this helpful, please give it a like!)