线性回归 梯度下降算法大全与基于Python的底层代码实现

简介: 梯度下降是一种常用的优化算法,它通过不断迭代来最小化一个损失函数。根据不同的损失函数和迭代方式,梯度下降可以被分为批量梯度下降(Batch Gradient Descent,BGD)、随机梯度下降(Stochastic Gradient Descent,SGD)、小批量梯度下降(Mini-batch Gradient Descent)、共轭梯度法(Conjugate Gradient,CG)等。

梯度下降是一种常用的优化算法,它通过不断迭代来最小化一个损失函数。根据不同的损失函数和迭代方式,梯度下降可以被分为批量梯度下降(Batch Gradient Descent,BGD)、随机梯度下降(Stochastic Gradient Descent,SGD)、小批量梯度下降(Mini-batch Gradient Descent)、共轭梯度法(Conjugate Gradient,CG)等。

1 批量梯度下降(Batch Gradient Descent,BGD)

批量梯度下降(Batch Gradient Descent,BGD)是一种非常常见的梯度下降算法,它通过在每一次迭代中计算所有训练样本的梯度来更新模型参数。其具体算法参见上一篇博文:线性回归 梯度下降原理与基于Python的底层代码实现


GD 的优点包括:


可以保证收敛到全局最优解,特别是在凸优化问题中。

可以利用矩阵运算,加速计算过程。

对于比较稠密的数据集,BGD 的计算速度较快。

BGD 的缺点包括:


需要处理所有训练样本,计算量较大,因此不适合处理大规模数据集。

可能会陷入局部最优解,特别是在非凸优化问题中。

对于比较稀疏的数据集,BGD 的计算效率较低,因为大部分数据都是无关的。

由于每次迭代都需要处理整个数据集,因此在处理在线学习或实时学习等实时数据流问题时,BGD 的计算效率也较低。

2 随机梯度下降(Stochastic Gradient Descent,SGD)

随机梯度下降(Stochastic Gradient Descent,SGD)是一种常用的梯度下降算法,它通过在每一次迭代中计算一个训练样本的梯度来更新模型参数。


SGD 的优点包括:


可以处理大规模、稀疏或实时数据流问题,因为每次只处理一个样本,计算效率较高。

可以跳出局部最优解,因为每次更新参数的方向不一定是相同的。

对于非凸优化问题,SGD 的表现可能会更好,因为它能够跳出局部最优解。

SGD 的缺点包括:


可能无法保证收敛到全局最优解,因为更新方向是随机的。

在处理比较稠密的数据集时,SGD 的计算速度可能较慢,因为需要频繁读取数据。

可能会出现震荡或抖动的情况,导致收敛速度较慢。

随机梯度下降算法适合处理大规模、稀疏或实时数据流问题,并且能够跳出局部最优解。但是,对于小规模、稠密或需要保证全局最优解的问题,SGD 的表现可能会不如批量梯度下降算法。同时,SGD 的收敛速度可能会受到震荡或抖动的影响,需要进行一些额外的优化或调整。


3 小批量梯度下降(Mini-batch Gradient Descent, MGD)

小批量梯度下降(Mini-batch Gradient Descent,MBGD)是一种介于批量梯度下降(Batch Gradient Descent,BGD)和随机梯度下降(Stochastic Gradient Descent,SGD)之间的梯度下降算法,它通过在每一次迭代中计算一小部分训练样本的梯度来更新模型参数。


MBGD 的优点包括:


可以利用矩阵运算,加速计算过程。

对于比较稠密的数据集,MBGD 的计算速度较快。

对于比较稀疏的数据集,MBGD 的计算效率也比较高。

可以保证收敛到全局最优解,特别是在凸优化问题中。

可以跳出局部最优解,因为每次更新参数的方向不一定是相同的。

MBGD 的缺点包括:


需要手动设置小批量大小,如果选择不当,可能会影响收敛速度和精度。

对于大规模、稀疏或实时数据流问题,MBGD 的计算效率可能不如 SGD,但比 BGD 要好。

小批量梯度下降算法是一种折中的梯度下降算法,可以在一定程度上平衡计算效率和收敛速度,适用于大部分深度学习模型的训练。但是,需要根据具体情况来选择小批量大小,以获得最好的效果。


4 共轭梯度法(Conjugate Gradient,CG)

共轭梯度法(Conjugate Gradient,CG)是一种针对特殊的矩阵结构进行求解的迭代方法,它可以快速收敛到全局最优解。CG 方法是一种迭代算法,每次更新的方向不同于梯度方向,但会沿着前一次更新方向和当前梯度方向的线性组合方向进行更新。CG 算法的迭代过程可以描述为以下步骤:


随机初始化模型参数。

计算梯度,并将其作为初始搜索方向。

沿着搜索方向更新模型参数。

计算新的梯度,并计算一个新的搜索方向,使得该方向与前一次搜索方向共轭。

重复步骤 3-4,直到达到预定的迭代次数或误差阈值。

CG 算法的优点包括:


可以快速收敛到全局最优解,特别是对于对称、正定的矩阵结构而言。

不需要存储所有历史梯度,可以节省内存空间。

在更新模型参数的方向上,CG 方法不需要进行线搜索,因此不需要设置学习率等参数。

CG 算法的缺点包括:


只适用于特定类型的矩阵结构,特别是对称、正定的矩阵结构而言。

对于非凸优化问题,CG 的表现可能会不如其他梯度下降算法。

由于需要额外的内存来存储一些临时变量,因此在处理大规模问题时可能会受到限制。

共轭梯度法适用于对称、正定的矩阵结构,可以快速收敛到全局最优解,并且不需要进行线搜索。但是,对于非凸优化问题和大规模问题,CG 的表现可能会受到一些限制。


5 不同梯度方法的底层代码实例

5.1 构造数据集

此处我们假定数据集仅x一个变量,x与y的关系为y = 8 x y=8xy=8x。下面将构造100个数据,x的取值范围为range(0, 10, 0.1)。

EXAMPLE_NUM = 100  
BATCH_SIZE = 10    
TRAIN_STEP = 150  
LEARNING_RATE = 0.0001  
X_INPUT = np.arange(EXAMPLE_NUM) * 0.1  
Y_OUTPUT_CORRECT = 8 * X_INPUT + np.random.uniform(low=-10, high=10)
def train_func(X, K):  
    result = K * X  
    return result


此处EXAMPLE_NUM为数据个数;BATCH_SIZE为小批量梯度下降每次使用的数据个数;TRAIN_STEP为迭代次数;X_INPUT为构造的x取值范围;Y_OUTPUT_CORRECT为对应的y真实值,这里根据xy的映射关系,在数据集上加入了(-10,10)的噪音。同时构造了train_func,用于后面梯度下降寻找最佳k值的使用。


5.2 BGD回归

k_BGD = 0.0  
k_BGD_RECORD = [0]  
for step in range(TRAIN_STEP):  
    SUM_BGD = 0  
    for index in range(len(X_INPUT)):  
     SUM_BGD += (train_func(X_INPUT[index], k_BGD) - Y_OUTPUT_CORRECT[index]) * X_INPUT[index]
    k_BGD -= LEARNING_RATE * SUM_BGD  
    k_BGD_RECORD.append(k_BGD)


k_BGD为给定的初始k值,k_BGD_RECORD用于记录梯度下降过程中K值的变化。第一个循环为在150次训练中的循环,第二个循环为依次计算每一个数据的梯度,并将所有计算结果求和。这里SUM_BGD为损失函数的导数。


损失函数为:


image.png

对损失函数求导后得:


image.png

image.png

在代码中没有写出1 m \frac{1}{m} ,是因为其本身也是一个常数,可以整合到学习速率之中。(因此相比于随机梯度下降,批量梯度下降的学习速率可以小一些,否则会学习过快)


5.3 SGD回归

k_SGD = 0.0  
k_SGD_RECORD = [0]  
for step in range(TRAIN_STEP*10):  
    index = np.random.randint(len(X_INPUT))  
    SUM_SGD = (train_func(X_INPUT[index], k_SGD) - Y_OUTPUT_CORRECT[index]) * X_INPUT[index]  
    k_SGD -= LEARNING_RATE * SUM_SGD  
    if step%10==0:  
        k_SGD_RECORD.append(k_SGD)


SGD就只有一个循环了,因为SGD每次只使用一个数据,同时考虑到训练速度,对其训练周期进行了10倍扩增。每计算一个数据的梯度之后,都会对k进行更新。由于k的更新较慢,因此我们采取每隔10次才记录一次k的变化值。


5.4 MGD回归

k_MBGD = 0.0  
k_MBGD_RECORD = [0]  
for step in range(TRAIN_STEP):  
    SUM_MBGD = 0  
    index_start = np.random.randint(len(X_INPUT) - BATCH_SIZE)  
    for index in np.arange(index_start, index_start+BATCH_SIZE):  
        SUM_MBGD += (train_func(X_INPUT[index], k_MBGD) - Y_OUTPUT_CORRECT[index]) * X_INPUT[index]  
    k_MBGD -= LEARNING_RATE * SUM_MBGD  
    k_MBGD_RECORD.append(k_MBGD)


MGD与BGD的代码很类似,只是每个step的数据只有BATCH_SIZE个。需要注意在随机选择数据起点时,其范围是0至len(X_INPUT) - BATCH_SIZE,以免数据的选择范围超出数据量。


5.4 不同方法的对比绘图

plt.plot(np.arange(TRAIN_STEP+1), np.array(k_BGD_RECORD), label='BGD')  
plt.plot(np.arange(TRAIN_STEP+1), k_SGD_RECORD, label='SGD')  
plt.plot(np.arange(TRAIN_STEP+1), k_MBGD_RECORD, label='MBGD')  
plt.legend()  
plt.ylabel('K')  
plt.xlabel('step')  
plt.show()



4a59fd5830d45d96b36a4bcc23f3a796.png



可以看到BGD的训练效果最快,这是因为BGD的数据量比另外两种多了10倍。通常情况下,BGD的曲线会更加平滑,另外两种方法会有偶尔的偏离正确值的情况。但BGD无法避开局部最优,由于本函数不存在局部最优,因此三种效果的拟合方法都还不错。


相关文章
|
18天前
|
开发框架 数据建模 中间件
Python中的装饰器:简化代码,增强功能
在Python的世界里,装饰器是那些静悄悄的幕后英雄。它们不张扬,却能默默地为函数或类增添强大的功能。本文将带你了解装饰器的魅力所在,从基础概念到实际应用,我们一步步揭开装饰器的神秘面纱。准备好了吗?让我们开始这段简洁而富有启发性的旅程吧!
26 6
|
11天前
|
数据可视化 Python
以下是一些常用的图表类型及其Python代码示例,使用Matplotlib和Seaborn库。
通过这些思维导图和分析说明表,您可以更直观地理解和选择适合的数据可视化图表类型,帮助更有效地展示和分析数据。
52 8
|
18天前
|
API Python
【Azure Developer】分享一段Python代码调用Graph API创建用户的示例
分享一段Python代码调用Graph API创建用户的示例
41 11
|
20天前
|
测试技术 Python
探索Python中的装饰器:简化代码,增强功能
在Python的世界中,装饰器是那些能够为我们的代码增添魔力的小精灵。它们不仅让代码看起来更加优雅,还能在不改变原有函数定义的情况下,增加额外的功能。本文将通过生动的例子和易于理解的语言,带你领略装饰器的奥秘,从基础概念到实际应用,一起开启Python装饰器的奇妙旅程。
34 11
|
16天前
|
Python
探索Python中的装饰器:简化代码,增强功能
在Python的世界里,装饰器就像是给函数穿上了一件神奇的外套,让它们拥有了超能力。本文将通过浅显易懂的语言和生动的比喻,带你了解装饰器的基本概念、使用方法以及它们如何让你的代码变得更加简洁高效。让我们一起揭开装饰器的神秘面纱,看看它是如何在不改变函数核心逻辑的情况下,为函数增添新功能的吧!
|
17天前
|
程序员 测试技术 数据安全/隐私保护
深入理解Python装饰器:提升代码重用与可读性
本文旨在为中高级Python开发者提供一份关于装饰器的深度解析。通过探讨装饰器的基本原理、类型以及在实际项目中的应用案例,帮助读者更好地理解并运用这一强大的语言特性。不同于常规摘要,本文将以一个实际的软件开发场景引入,逐步揭示装饰器如何优化代码结构,提高开发效率和代码质量。
42 6
|
21天前
|
存储 算法 程序员
C 语言递归算法:以简洁代码驾驭复杂逻辑
C语言递归算法简介:通过简洁的代码实现复杂的逻辑处理,递归函数自我调用解决分层问题,高效而优雅。适用于树形结构遍历、数学计算等领域。
|
文字识别 算法 前端开发
100行Python代码实现一款高精度免费OCR工具
近期Github开源了一款基于Python开发、名为Textshot的截图工具,刚开源不到半个月已经500+Star。
100行Python代码实现一款高精度免费OCR工具
|
17天前
|
人工智能 数据可视化 数据挖掘
探索Python编程:从基础到高级
在这篇文章中,我们将一起深入探索Python编程的世界。无论你是初学者还是有经验的程序员,都可以从中获得新的知识和技能。我们将从Python的基础语法开始,然后逐步过渡到更复杂的主题,如面向对象编程、异常处理和模块使用。最后,我们将通过一些实际的代码示例,来展示如何应用这些知识解决实际问题。让我们一起开启Python编程的旅程吧!
|
16天前
|
存储 数据采集 人工智能
Python编程入门:从零基础到实战应用
本文是一篇面向初学者的Python编程教程,旨在帮助读者从零开始学习Python编程语言。文章首先介绍了Python的基本概念和特点,然后通过一个简单的例子展示了如何编写Python代码。接下来,文章详细介绍了Python的数据类型、变量、运算符、控制结构、函数等基本语法知识。最后,文章通过一个实战项目——制作一个简单的计算器程序,帮助读者巩固所学知识并提高编程技能。
下一篇
DataWorks