机器学习中的数学原理——随机梯度下降法

简介: 机器学习中的数学原理——随机梯度下降法

一、什么是随机梯度下降法

随机梯度下降是随机取样替代完整的样本,主要作用是提高迭代速度,避免陷入庞大计算量的泥沼。 对于整个样本做GD又称为批梯度下降(BGD,batch gradient descent)。 随机梯度下降(SGD, stochastic gradient descent) :名字中已经体现了核心思想,随机选取一个店做梯度下降,而不是遍历所有样本后进行参数迭代。

二、算法分析

之前我们介绍过最速梯度下降法,即只要向与导数的符号相反的方向移动 x,g(x) 就会自然而然地沿着最小值的方向前进了。也就是自动更新参数。但是最速下降法除了计算花时间以外,还有一个缺点——容易陷入局部最优解。在讲解回归时,我们使用的是平方误差目标函数。这个函数形式 简单,所以用最速下降法也没有问题。现在我们来考虑稍微复杂 一点的,比如下列的图这种形状的函数:

最速下降法来找函数的最小值时,必须先要决定从哪个 x 开始 找起。之前我们用 g(x) 说明的时候是从 x = 3 或者 x = −1 开始的,这是我们随便选择的,选用随机数作为初始值的情况比较多。不过这样每次初始值都会变,进而导致陷入局部最优解的问题。假设这张图中标记的位置就是初始值:

那么从这个点开始找,可以求出最小值。但是如果将下列这个点作为初始值,还没计算完就会停止,这就叫陷入了局部最优解

随机梯度下降法就是以最速下降法为基础的。我们先复习一下最速下降法,还记得最速下降法的参数更新表达式吗?

这个表达式使用了所有训练数据误差,而在随机梯度下降法中会随机选择一个训练数据并使用它来更新参数。这个表达 式中的 k 就是被随机选中的数据索引

所以最速下降法更新 1 次参数的时间,随机梯度下降法可以更新 n 次,速度上明显提升。 此外,随机梯度下降法由于训练数据是随机选择的,更新参数时使用的又是选择数据时的梯度,所以不容易陷入目标函数的局部最优解。并且在实际计算过程中,的确会收敛!这样的做法就叫做随机梯度下降法

上述提到的是随机选择1个训练数据的做法,此外还有随机选择 m 个训练数据来更新参数的做法。设随机选择 m 个训练数据的索引的集合为 K,那么我们这样来更新参数:

假设训练数据有 100 个,那么在 m = 10 时,创建一个有 10 个随机数的索引的集合,例如 K = {61, 53, 59, 16, 30, 21, 85, 31, 51, 10},然后重复更新参数就可以了,这样的做法称为小批量(mini-batch)梯度下降法,这是一种介于最速下降法和随机梯度下降法之间的方法

三、总结

不管是随机梯度下降法还是小批量梯度下降法,我们都必须考虑学习率 η。把 η 设置为合适的值是很重要的。学习率的决定是一个很难的问题,可以通过反复尝试来找到合适的值,除此之外还有其他的几个办法,这是我们后续所要学习的!这三个算法总结成一句话:最速梯度下降法是用每个对应的估计值减去实际值求和,随机梯度下降法是用选定的一个估计值减去实际值求和,批量梯度下降法是用选定的多个估计值减去实际值求和。


相关文章
|
3天前
|
机器学习/深度学习 算法 搜索推荐
【机器学习】机器学习的基本概念、算法的工作原理、实际应用案例
机器学习是人工智能的一个分支,它使计算机能够在没有明确编程的情况下从数据中学习并改进其性能。机器学习的目标是让计算机自动学习模式和规律,从而能够对未知数据做出预测或决策。
8 2
|
7天前
|
机器学习/深度学习 人工智能 关系型数据库
【机器学习】Qwen2大模型原理、训练及推理部署实战
【机器学习】Qwen2大模型原理、训练及推理部署实战
44 0
【机器学习】Qwen2大模型原理、训练及推理部署实战
|
14天前
|
机器学习/深度学习 运维 算法
深入探索机器学习中的支持向量机(SVM)算法:原理、应用与Python代码示例全面解析
【8月更文挑战第6天】在机器学习领域,支持向量机(SVM)犹如璀璨明珠。它是一种强大的监督学习算法,在分类、回归及异常检测中表现出色。SVM通过在高维空间寻找最大间隔超平面来分隔不同类别的数据,提升模型泛化能力。为处理非线性问题,引入了核函数将数据映射到高维空间。SVM在文本分类、图像识别等多个领域有广泛应用,展现出高度灵活性和适应性。
68 2
|
14天前
|
机器学习/深度学习
【机器学习】面试题:LSTM长短期记忆网络的理解?LSTM是怎么解决梯度消失的问题的?还有哪些其它的解决梯度消失或梯度爆炸的方法?
长短时记忆网络(LSTM)的基本概念、解决梯度消失问题的机制,以及介绍了包括梯度裁剪、改变激活函数、残差结构和Batch Normalization在内的其他方法来解决梯度消失或梯度爆炸问题。
27 2
|
14天前
|
机器学习/深度学习 算法 数据挖掘
|
7天前
|
机器学习/深度学习 数据采集 物联网
【机器学习】Google开源大模型Gemma2:原理、微调训练及推理部署实战
【机器学习】Google开源大模型Gemma2:原理、微调训练及推理部署实战
26 0
|
7天前
|
机器学习/深度学习 人工智能 自然语言处理
【机器学习】GLM4-9B-Chat大模型/GLM-4V-9B多模态大模型概述、原理及推理实战
【机器学习】GLM4-9B-Chat大模型/GLM-4V-9B多模态大模型概述、原理及推理实战
30 0
|
17天前
|
机器学习/深度学习 算法
【机器学习】梯度消失和梯度爆炸的原因分析、表现及解决方案
本文分析了深度神经网络中梯度消失和梯度爆炸的原因、表现形式及解决方案,包括梯度不稳定的根本原因以及如何通过网络结构设计、激活函数选择和权重初始化等方法来解决这些问题。
9 0
|
1月前
|
机器学习/深度学习 自然语言处理 算法
扩散模型在机器学习中的应用及原理
扩散模型在机器学习中的应用及原理
|
2月前
|
机器学习/深度学习 算法 BI
机器学习笔记(一) 感知机算法 之 原理篇
机器学习笔记(一) 感知机算法 之 原理篇