机器学习入门:梯度下降算法(下)

简介: 机器学习入门:梯度下降算法(下)

学习目标

🍀 了解全梯度下降,随机梯度下降,小批量梯度下降,随机平均梯度下降的原理

🍔 全梯度下降算法(FGD)

全梯度下降算法(FGD)-----每次迭代时, 使用全部样本的梯度值

批量梯度下降法,是梯度下降法最常用的形式,具体做法也就是在更新参数时使用所有的样本来进行更新。

计算训练集所有样本误差对其求和再取平均值作为目标函数

权重向量沿其梯度相反的方向移动,从而使当前目标函数减少得最多。

其是在整个训练数据集上计算损失函数关于参数\theta 的梯度:

由于我们有m个样本,这里求梯度的时候就用了所有m个样本的梯度数据。

如下图💯:

注意:

  • 因为在执行每次更新时,我们需要在整个数据集上计算所有的梯度,所以批梯度下降法的速度会很慢,同时,全梯度下降法无法处理超出内存容量限制的数据集。
  • 全梯度下降法同样也不能在线更新模型,即在运行的过程中,不能增加新的样本

🍔 随机梯度下降算法(SGD)

  • 随机梯度下降算法(SGD)
  • 每次迭代时, 随机选择并使用一个样本梯度值

由于FG每迭代更新一次权重都需要计算所有样本误差,而实际问题中经常有上亿的训练样本,故效率偏低,且容易陷入局部最优解,因此提出了随机梯度下降算法。

其每轮计算的目标函数不再是全体样本误差,而仅是单个样本误差,即 每次只代入计算一个样本目标函数的梯度来更新权重,再取下一个样本重复此过程,直到损失函数值停止下降或损失函数值小于某个可以容忍的阈值。

此过程简单,高效,通常可以较好地避免更新迭代收敛到局部最优解。其迭代形式为

但是由于,SG每次只使用一个样本迭代,若遇上噪声则容易陷入局部最优解。

Sklearn提供了随机梯度下降的API

from sklearn.linear_model import SGDRegressor

🍔 小批量梯度下降算法(mini-bantch)

  • 小批量梯度下降算法(mini-bantch)
  • 每次迭代时, 随机选择并使用小批量的样本梯度值

小批量梯度下降算法是FG和SG的折中方案,在一定程度上兼顾了以上两种方法的优点。

每次从训练样本集上随机抽取一个小样本集,在抽出来的小样本集上采用FG迭代更新权重。

被抽出的小样本集所含样本点的个数称为batch_size,通常设置为2的幂次方,更有利于GPU加速处理。

特别的,若batch_size=1,则变成了SG;若batch_size=n,则变成了FG.其迭代形式为

上式中,也就是我们从m个样本中,选择x个样本进行迭代(1<x<m),

🍔 随机平均梯度下降算法(SAG)

随机平均梯度下降算法(SAG)

  • 每次迭代时, 随机选择一个样本的梯度值和以往样本的梯度值的均值

在SG方法中,虽然避开了运算成本大的问题,但对于大数据训练而言,SG效果常不尽如人意,因为每一轮梯度更新都完全与上一轮的数据和梯度无关。

随机平均梯度算法克服了这个问题,在内存中为每一个样本都维护一个旧的梯度,随机选择第i个样本来更新此样本的梯度,其他样本的梯度保持不变,然后求得所有梯度的平均值,进而更新了参数。

如此,每一轮更新仅需计算一个样本的梯度,计算成本等同于SG,但收敛速度快得多。

其迭代形式为:

  • 我们知道sgd是当前权重减去步长乘以梯度,得到新的权重。sag中的a,就是平均的意思,具体说,就是在第k步迭代的时候,我考虑的这一步和前面n-1个梯度的平均值,当前权重减去步长乘以最近n个梯度的平均值。
  • n是自己设置的,当n=1的时候,就是普通的sgd。
  • 这个想法非常的简单,在随机中又增加了确定性,类似于mini-batch sgd的作用,但不同的是,sag又没有去计算更多的样本,只是利用了之前计算出来的梯度,所以每次迭代的计算成本远小于mini-batch sgd,和sgd相当。效果而言,sag相对于sgd,收敛速度快了很多。这一点下面的论文中有具体的描述和证明。
  • SAG论文链接:https://arxiv.org/pdf/1309.2388.pdf

🍔 小结

🍬 全梯度下降算法(FG):在进行梯度下降迭代时,所有样本均参与计算

🍬 随机梯度下降算法(SG):在进行梯度下降迭代时,每次迭代只选取一个样本进行计算

🍬 小批量梯度下降算法(mini-batch):在进行梯度下降迭代时,每次迭代只选取一部分样本进行计算

🍬 随机平均梯度下降算法(SAG):每次迭代时, 随机选择一个样本的梯度值和以往样本的梯度值的均值

相关文章
|
2月前
|
机器学习/深度学习 数据采集 人工智能
20分钟掌握机器学习算法指南
在短短20分钟内,从零开始理解主流机器学习算法的工作原理,掌握算法选择策略,并建立对神经网络的直观认识。本文用通俗易懂的语言和生动的比喻,帮助你告别算法选择的困惑,轻松踏入AI的大门。
147 7
|
3月前
|
机器学习/深度学习 存储 Kubernetes
【重磅发布】AllData数据中台核心功能:机器学习算法平台
杭州奥零数据科技有限公司成立于2023年,专注于数据中台业务,维护开源项目AllData并提供商业版解决方案。AllData提供数据集成、存储、开发、治理及BI展示等一站式服务,支持AI大模型应用,助力企业高效利用数据价值。
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
AI训练师入行指南(三):机器学习算法和模型架构选择
从淘金到雕琢,将原始数据炼成智能珠宝!本文带您走进数字珠宝工坊,用算法工具打磨数据金砂。从基础的经典算法到精密的深度学习模型,结合电商、医疗、金融等场景实战,手把手教您选择合适工具,打造价值连城的智能应用。掌握AutoML改装套件与模型蒸馏术,让复杂问题迎刃而解。握紧算法刻刀,为数字世界雕刻文明!
147 6
|
4月前
|
机器学习/深度学习 算法 机器人
强化学习:时间差分(TD)(SARSA算法和Q-Learning算法)(看不懂算我输专栏)——手把手教你入门强化学习(六)
本文介绍了时间差分法(TD)中的两种经典算法:SARSA和Q-Learning。二者均为无模型强化学习方法,通过与环境交互估算动作价值函数。SARSA是On-Policy算法,采用ε-greedy策略进行动作选择和评估;而Q-Learning为Off-Policy算法,评估时选取下一状态中估值最大的动作。相比动态规划和蒙特卡洛方法,TD算法结合了自举更新与样本更新的优势,实现边行动边学习。文章通过生动的例子解释了两者的差异,并提供了伪代码帮助理解。
327 2
|
5月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于机器学习的人脸识别算法matlab仿真,对比GRNN,PNN,DNN以及BP四种网络
本项目展示了人脸识别算法的运行效果(无水印),基于MATLAB2022A开发。核心程序包含详细中文注释及操作视频。理论部分介绍了广义回归神经网络(GRNN)、概率神经网络(PNN)、深度神经网络(DNN)和反向传播(BP)神经网络在人脸识别中的应用,涵盖各算法的结构特点与性能比较。
|
5月前
|
人工智能 编解码 算法
使用 PAI-DSW x Free Prompt Editing图像编辑算法,开发个人AIGC绘图小助理
使用 PAI-DSW x Free Prompt Editing图像编辑算法,开发个人AIGC绘图小助理
|
6月前
|
机器学习/深度学习 算法 网络安全
CCS 2024:如何严格衡量机器学习算法的隐私泄露? ETH有了新发现
在2024年CCS会议上,苏黎世联邦理工学院的研究人员提出,当前对机器学习隐私保护措施的评估可能存在严重误导。研究通过LiRA攻击评估了五种经验性隐私保护措施(HAMP、RelaxLoss、SELENA、DFKD和SSL),发现现有方法忽视最脆弱数据点、使用较弱攻击且未与实际差分隐私基线比较。结果表明这些措施在更强攻击下表现不佳,而强大的差分隐私基线则提供了更好的隐私-效用权衡。
174 14
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
解锁机器学习的新维度:元学习的算法与应用探秘
元学习作为一个重要的研究领域,正逐渐在多个应用领域展现其潜力。通过理解和应用元学习的基本算法,研究者可以更好地解决在样本不足或任务快速变化的情况下的学习问题。随着研究的深入,元学习有望在人工智能的未来发展中发挥更大的作用。
|
29天前
|
机器学习/深度学习 算法 数据挖掘
基于WOA鲸鱼优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
本项目基于MATLAB 2022a/2024b实现,采用WOA优化的BiLSTM算法进行序列预测。核心代码包含完整中文注释与操作视频,展示从参数优化到模型训练、预测的全流程。BiLSTM通过前向与后向LSTM结合,有效捕捉序列前后文信息,解决传统RNN梯度消失问题。WOA优化超参数(如学习率、隐藏层神经元数),提升模型性能,避免局部最优解。附有运行效果图预览,最终输出预测值与实际值对比,RMSE评估精度。适合研究时序数据分析与深度学习优化的开发者参考。
|
19天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于PSO粒子群优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
本项目基于MATLAB2022a/2024b开发,结合粒子群优化(PSO)算法与双向长短期记忆网络(BiLSTM),用于优化序列预测任务中的模型参数。核心代码包含详细中文注释及操作视频,涵盖遗传算法优化过程、BiLSTM网络构建、训练及预测分析。通过PSO优化BiLSTM的超参数(如学习率、隐藏层神经元数等),显著提升模型捕捉长期依赖关系和上下文信息的能力,适用于气象、交通流量等场景。附有运行效果图预览,展示适应度值、RMSE变化及预测结果对比,验证方法有效性。

热门文章

最新文章