TensorFlow 2 Keras实现线性回归

简介: TensorFlow 2 Keras实现线性回归

介绍


线性回归是入门机器学习必学的算法,其也是最基础的算法之一。


接下来,我们以线性回归为例,使用 TensorFlow 2 提供的 API 和 Eager Execution 机制对其进行实现。


线性回归是一种较为简单,但十分重要的机器学习方法,它也是神经网络的基础。


如下所示,线性回归要解决的问题就是如何找到最理想的直线去拟合散点样本。


13.png


对于一个线性回归问题,一般来讲有 2 种解决方法,分别是:


  • 最小二乘法

代数求解

矩阵求解

  • 梯度下降法。

本次,我们将使用梯度下降方法来解决线性回归问题。


Keras 方式实现


配合 TensorFlow 提供的高阶 API,我们省去了定义线性函数,定义损失函数,以及定义优化算法等 3 个步骤。


不过,高阶 API 实现过程实际上还不够精简,我们可以完全使用 TensorFlow Keras API 来实现线性回归。


Keras 本来是一个用 Python 编写的独立高阶神经网络 API,它能够以 TensorFlow, CNTK,或者 Theano 作为后端运行。


目前,TensorFlow 已经吸纳 Keras,并组成了 tf.keras 模块。官方介绍,tf.keras 和单独安装的 Keras 略有不同,但考虑到未来的发展趋势,主要以学习 tf.keras 为主。


12.png


初始化


11.png


我们这里使用 Keras 提供的 Sequential 顺序模型结构。向其中添加一个线性层。不同的地方在于,Keras 顺序模型第一层为线性层时,规定需指定输入维度,这里为 input_dim=1。


10.png


接下来,直接使用 .compile 编译模型,指定损失函数为 MSE 平方损失函数,优化器选择 SGD 随机梯度下降。然后,就可以使用 .fit 传入数据开始迭代了。


batch_size 是采用小批次训练的参数,主要用于解决一次性传入数据过多无法训练的问题。当然,由于示例数据本身较少,这里意义不大,但还是按照常规使用方法进行设置。


你会发现,完全使用 Keras 高阶 API 实际上只需要 4 行核心代码即可完成,相比于低阶 API 简化了很多。


完整代码:

import tensorflow as tf
TRUE_W = 3.0
TRUE_b = 2.0
NUM_SAMPLES = 100
X = tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
noise = tf.random.normal(shape=[NUM_SAMPLES,1]).numpy()
y = X * TRUE_W + TRUE_b + noise
# 模型训练
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1,input_dim=1))
model.compile(optimizer='sgd',loss='mse')
model.fit(X,y,epochs=10,batch_size=32)
目录
相关文章
|
2天前
|
机器学习/深度学习 数据可视化 TensorFlow
Python用线性回归和TensorFlow非线性概率神经网络不同激活函数分析可视化
Python用线性回归和TensorFlow非线性概率神经网络不同激活函数分析可视化
|
2天前
|
机器学习/深度学习 算法 TensorFlow
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
|
2天前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow、Keras 和 Python 构建神经网络分析鸢尾花iris数据集|代码数据分享
TensorFlow、Keras 和 Python 构建神经网络分析鸢尾花iris数据集|代码数据分享
|
2天前
|
机器学习/深度学习 人工智能 算法框架/工具
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)(4)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)
36 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(八)(4)
|
2天前
|
机器学习/深度学习 算法框架/工具 TensorFlow
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(4)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)
48 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(4)
|
2天前
|
机器学习/深度学习 算法 算法框架/工具
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(3)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)
18 0
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(3)
|
2天前
|
机器学习/深度学习 算法框架/工具 自然语言处理
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)(1)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(七)
33 0
|
2天前
|
机器学习/深度学习 算法框架/工具 TensorFlow
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(五)(3)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(五)
12 0
|
2天前
|
机器学习/深度学习 算法框架/工具 Python
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(五)(2)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(五)
31 0
|
2天前
|
机器学习/深度学习 算法框架/工具 Python
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(五)(1)
Sklearn、TensorFlow 与 Keras 机器学习实用指南第三版(五)
37 0