三、参数学习

简介: 三、参数学习

1、梯度下降



为了使得代价函数最小化,引入梯度下降算法来寻找使得代价函数最小化的参数取值。


假设现在的代价函数为J(θ0, θ1),我们想要最小化代价函数 minJ(θ0, θ1)。为了进行上述寻优,首先我们需要给定一个初始值  θ0, θ1;之后需要改变 θ0, θ1的值来减小 J(θ0, θ1)直到最小化代价函数。这个重复改变参数  θ0, θ1的值的算法就是梯度下降算法,下面的图直观展示的梯度下降算法的执行过程。


1a78b7cb499e46968f1918361b1c598a.png469cd5550160428abc4eaef19fb51d0d.png


 


从上述两幅图可以看出,选择不同的起点之后,梯度下降算法找到的局部最优位置大不相同,这是梯度下降算法的一个特点。


梯度下降算法的算法流程如下所示:


Repeatuntilconvergence{θj:=θjαθjJ(θ0,θ1)}


其中,j=0,1表示特征的索引; :=表示赋值符号; α \alpha α表示每一步移动的步长,也叫学习率。


梯度下降算法在每个点处计算代价函数关于各个参数的偏导数(函数关于各个参数在当前点处的切线斜率),这个偏导数会指明下一步点应该移动的方向。若有多个偏导数,会选择其中梯度最大的方向进行移动,移动的距离由步长 α \alpha α来决定。


注意在使用梯度下降方法时,所有参数的更新需要同时进行,即针对同一个代价函数来更新所有的参数,不能更新完某个参数之后改变了代价函数之后再去更新其他参数。

2a043ef63021455fbe5c34283e7384e2.png





2、梯度下降算法怎样工作



下面通过单参数代价函数最小化为例来具体介绍梯度下降的工作流程:之前已经介绍了代价函数关于单参数 θ 1 \theta_1 θ1的图像类似于一个二次函数,首先当求导项大于0时,如下图所示:


6f952b9bfbcf4076a1f40d6305e90b42.png


可以发现,从开始的红色点出发,经过一次迭代之后,到达蓝色点之后,算法将代价函数的目标函数值缩小了,向着最低点进行了移动。


当求导项小于0时,如下图所示:3f084047b1f64e4fb5f74cbec0988a17.png



可以发现,从开始的红色点出发,经过一次迭代之后,到达蓝色点,算法同样会将代价函数的目标函数值缩小,向着最低点进行移动。


2.1 α \alpha α取值大小的影响


如果 α \alpha α的取值比较小,则将会耗费更多的时间才能收敛到代价函数的最低点:


be8e4059eb424110a9365bcd46d344b8.png

如果     α的取值过大,则算法可能不能达到收敛状态,甚至会出现发散状态:


bc823560a0d44bae9fde3eadff873c12.png



当采用合适的步长  α之后,在局部最优但,函数的梯度(导数/斜率)为0,函数收敛到局部最优解。

911de4f639234c3aa42781886e64dd25.png


因为越靠近局部最优解,梯度会越来越小,所以就算采用固定值的步长 α \alpha α,梯度下降算法也会自动采用越来越小的步长。


8fb10b6b7b0845039e9a1cb5bb4ddc34.png



2.2 应用梯度下降最小化双参数代价函数


梯度下降的流程如下所示:

Repeatuntilconvergence{θj:=θjαθjJ(θ0,θ1)}


线性回归模型如下所示:

h(x)=θ0+θ1xJ(θ0,θ1)=2m1i=1m(h(xi)yi)2


从上述式子中可以看出,为了应用梯度下降算法,最主要的工作需要求出代价函数关于参数 θ 0 \theta_0 θ0和 θ 1 \theta_1 θ1的偏导数。求导过程如下所示:


4e7afa342e8e416599098618c31bed66.png


将上述求出的偏导数带到梯度下降的流程中,可以得到下式:

Repeatuntilconvergence{θ0:=θ0αm1i=1m(h(xi)yi)θ1:=θ1αm1i=1m(h(xi)yi)xi}


应用上述梯度下降算法,最终收敛的图像如下图所示:


image.png


上述介绍的梯度下降算法也可以叫"批"处理次梯度下降算法-Batch Gradient Descent,其中“批”指的是:次梯度下降的每一步都使用所有的训练集来进行计算。


相关文章
|
4月前
Ceres库中参数理解
Ceres库中参数的理解,特别是仿函数中传参的含义,并提供了一个LeetCode问题的链接,该问题要求找出数组中和为目标值的两个数。
|
Java
JVM参数调优基础-参数的类型详解(上)
JVM参数调优基础-参数的类型详解(上)
167 0
JVM参数调优基础-参数的类型详解(上)
|
前端开发
前端学习案例-参数默认值是函数1
前端学习案例-参数默认值是函数1
69 0
前端学习案例-参数默认值是函数1
|
前端开发
前端学习案例-参数默认值是函数2
前端学习案例-参数默认值是函数2
96 0
前端学习案例-参数默认值是函数2
|
PyTorch 算法框架/工具
torch 一个网络的参数通过训练后得到新的参数,如何再将这个网络参数初始化到定义这个网络的时候参数
可以使用PyTorch中的state_dict()方法将当前训练得到的网络参数保存为一个字典,然后在需要重新初始化网络参数时,可以通过load_state_dict()方法将之前保存的字典加载到网络模型中。具体步骤如下: 1. 在训练完成后,使用
231 0
|
Java Scala 开发者
作为参数的函数 | 学习笔记
快速学习作为参数的函数
|
开发者 Python
函数的参数| 学习笔记
快速学习函数的参数
|
开发者 Python
多个参数| 学习笔记
快速学习多个参数
|
异构计算
MMsegmentation教程-Config参数解释
MMsegmentation教程-Config参数解释
762 0
self.doubleSpinBox.setGeometry(QtCore.QRect(20, 25, 101, 22))参数讲解
self.doubleSpinBox.setGeometry(QtCore.QRect(20, 25, 101, 22))参数讲解
327 0

热门文章

最新文章