使用TensorFlow2.0的 Keras实现线性回归 训练模型
Keras实现单变量线性回归;使用场景:根据工作小时得出报酬,使用的Anaconda 进行的操作
Anaconda下载地址
计算公式 其中x:代表工作显示数 f(x) :代表工作报仇 a和b:通过 *梯度下降算法* 计算出来的值; 梯度下降算法:是线性回归的核心算法 f(x)=xa+b
import tensorflow as tf • 1
print("Tf V{}".format(tf.__version__))
Tf V2.4.1 • 1
#[pandas中文网站](https://www.pypandas.cn) import pandas as pd
#引入图表库 import matplotlib.pyplot as plt %matplotlib inline • 1 • 2 • 3
data=pd.read_csv("./work_hours.csv")
data • 1
#生成线性图表 plt.scatter(data.word,data.money) • 1 • 2
<matplotlib.collections.PathCollection at 0x21fb71726a0>
# 图中的 变量 x=data.word y=data.money • 1 • 2 • 3
#初始化 顺序模型 model=tf.keras.Sequential()
#添加层 Dense(维度,) model.add(tf.keras.layers.Dense(1,input_shape=(1,))) • 1 • 2
#显示模型层 model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 1) 2 ================================================================= Total params: 2 Trainable params: 2 Non-trainable params: 0 _________________________________________________________________
#编译、配置 optimizer:优化方法 名 #loss :损失值 model.compile(optimizer="adam", loss="mse" )
#训练 epochs:训练次数 (应该是训练次数越多 月稳定) history=model.fit(x,y,epochs=50000)
# x:表里面的 word(工作小时数) 那一行的值 model.predict(x)
array([[ 53.363663], [ 60.72731 ], [ 68.09096 ], [ 75.454605], [ 82.81825 ], [ 90.18191 ], [ 97.545555], [104.9092 ], [112.27285 ], [119.6365 ]], dtype=float32)
#假设工作4小时 model.predict(pd.Series([4])) • 1 • 2
array([[75.454605]], dtype=float32)
#假设工作14小时 model.predict(pd.Series([14])) • 1 • 2
array([[149.0911]], dtype=float32)
#假设工作8小时 model.predict(pd.Series([8])) • 1 • 2
array([[104.9092]], dtype=float32)
#假设工作40小时 model.predict(pd.Series([40])) • 1 • 2
array([[340.54596]], dtype=float32)
#假设工作24小时 model.predict(pd.Series([24])) • 1 • 2
array([[222.72758]], dtype=float32)
#假设工作1小时 model.predict(pd.Series([1])) • 1 • 2
array([[53.363663]], dtype=float32)
#假设工作3小时 model.predict(pd.Series([3])) • 1 • 2
array([[68.09096]], dtype=float32)
使用到的文件格式
一定要使用的csv格式 的 文件,创建csv格式文件可以去百度搜索