神经网络如何学习的?

简介: 毫无疑问,神经网络是目前使用的最流行的机器学习技术。所以我认为了解神经网络如何学习是一件非常有意义的事。


像下山一样,找到损失函数的最低点。

image


毫无疑问,神经网络是目前使用的最流行的机器学习技术。所以我认为了解神经网络如何学习是一件非常有意义的事。


为了能够理解神经网络是如何进行学习的,让我们先看看下面的图片:


image


如果我们把每一层的输入和输出值表示为向量,把权重表示为矩阵,把误差表示为向量,那么我们就得到了上述的一个神经网络的视图,它只是一系列向量函数的应用。也就是说,函数将向量作为输入,对它们进行一些转换,然后把变换后的向量输出。在上图中,每条线代表一个函数,它可以是一个矩阵乘法加上一个误差向量,也可以是一个激活函数。这些圆表示这些函数作用的向量。


例如,我们从输入向量开始,然后将其输入到第一个函数中,该函数用来计算其各分量的线性组合,然后我们将获得的向量作为输出。然后把这个向量作为激活函数的输入,如此类推,直到我们到达序列中的最后一个函数。最后一个函数的输出就是神经网络的预测值。


到目前为止,我们已经讨论过神经网络是如何得到输出的,这正是我们感兴趣的内容。我们知道神经网络只是将它的输入向量传递给一系列函数。但是这些函数要依赖于一些参数:权重和误差。


神经网络如何通过学习得到这些参数来获得好的预测呢?


让我们回想一下神经网络实际上是什么:实际上它只是一个函数,是由一个个小函数按顺序排列组成的大函数。这个函数有一组参数,在一开始,我们并不知道这些参数应该是什么,我们仅仅是随机初始化它们。因此在一开始神经网络会给我们一些随机的值。那么我们如何改进他们呢?在尝试改进它们之前,我们首先需要一种评估神经网络性能的方法。如果我们没有办法衡量模型的好坏,那么我们应该如何改进模型的性能?


为此,我们需要设计一个函数,这个函数将神经网络的预测值和数据集中的真实标签作为输入,将一个代表神经网络性能的数字作为输出。然后我们就可以将学习问题转化为求函数的最小值或最大值的优化问题。在机器学习领域,这个函数通常是用来衡量我们的预测有多糟糕,因此被称为损失函数。我们的问题就变成了找到使这个损失函数最小化的神经网络参数。


随机梯度下降算法


你可能很擅长从微积分中求函数的最小值。对于这种问题,通常取函数的梯度,令其等于0,求出所有的解(也称为临界点),然后从中选择使函数值最小的那一个。这就是全局最小值。我们能做同样的事情来最小化我们的损失函数吗?事实上是行不通的,主要的问题是神经网络的损失函数并不像微积分课本中常见的那样简洁明了。它是一个极其复杂的函数,有数千个、几十万个甚至数百万个参数。有时甚至不可能找到一个解决问题的收敛解。这个问题通常是通过迭代的方法来解决的,这些方法并不试图找到一个直接的解,而是从一个随机的解开始,并在每次迭代中尝试改进一点。最终,经过大量的迭代,我们将得到一个相当好的解决方案。


其中一种迭代方法是梯度下降法。你可能知道,一个函数的梯度给出了最陡的上升方向,如果我们取梯度的负值,它会给我们最陡下降的方向,也就是我们可以在这个方向上最快地达到最小值。因此,在每一次迭代(也可以将其称作一次训练轮次)时,我们计算损失函数的梯度,并从旧参数中减去它(乘以一个称为学习率的因子)以得到神经网络的新参数。


image


其中θ(theta)表示包含神经网络所有参数的向量。


在标准梯度下降法中,梯度是将整个数据集考虑进来并进行计算的。通常这是不可取的,因为该计算可能是昂贵的。在实践中,数据集被随机分成多个块,这些块被称为批。对每个批进行更新。这种方法就叫做随机梯度下降。


上面的更新规则在每一步只考虑在当前位置计算的梯度。这样,在损失函数曲面上运动的点的轨迹对任何变动都很敏感。有时我们可能想让这条轨迹更稳健。为此,我们使用了一个受物理学启发的概念:动量。我们的想法是,当我们进行更新时,也考虑到以前的更新,这会累积成一个变量Δθ。如果在同一个方向上进行更多的更新,那么我们将"更快"地朝这个方向前进,并且不会因为任何小的扰动而改变我们的轨迹。把它想象成速度。


image


其中α是非负因子,它可以决定旧梯度到底可以贡献多少值。当它为0时,我们不使用动量。


反向传播算法


我们如何计算梯度呢?回想一下神经网络和损失函数,它们只是一个函数的组合。那么如何计算复合函数的偏导数呢?我们可以使用链式法则。让我们看看下面的图片:


image


如果我们要计算损失函数对第一层权重参数的偏导数:我们首先让第一个线性表达式对权重参数求偏导,然后用这个结果乘上下一个函数(也就是激活函数)关于它前面函数输出内容的偏导数,一直执行这个操作,直到我们乘上损失函数关于最后一个激活函数的偏导数。那如果我们想要计算对第二层的权重参数求的导数呢?我们必须做同样的过程,但是这次我们从第二个线性组合函数对权重参数求导数开始,然后,我们要乘的其他项在计算第一层权重的导数时也出现了。所以,与其一遍又一遍地计算这些术语,我们将从后向前计算,因此得名为反向传播算法。


我们将首先计算出损失函数关于神经网络输出层的偏导数,然后通过保持导数的运行乘积将这些导数反向传播到第一层。需要注意的是,我们有两种导数:一种是函数关于它输入内容的导数。我们把它们乘以导数的乘积,目的是跟踪神经网络从输出层到当前层神经元节点的误差。第二类导数是关于参数的,这类导数是我们用来优化参数的。我们不把它与其它导数的乘积相乘,相反,我们将它们存储为梯度的一部分,稍后我们将使用它来更新参数。


所以,在反向传播时,当我们遇到没有可学习参数的函数时(比如激活函数),我们只取第一种的导数,只是为了反向传播误差。但是,当我们遇到的函数有可学的参数(如线性组合,有权重和偏差),那么我们取这两种导数:第一种是用误差传播的输入,第二种是加权和偏差,并将它们作为梯度的一部分来存储。整个过程,我们从损失函数开始,直到我们到达第一层,在这一层我们没有任何想要添加到梯度中的可学习参数。这就是反向传播算法。


Softmax激活和交叉熵损失函数


分类任务中,最后一层常用的激活函数是softmax函数。


image


softmax函数将其输入向量转换为概率分布。从上图中可以看到softmax的输出的向量元素都是正的,它们的和是1。当我们使用softmax激活时,我们在神经网络最后一层创建与数据集中类数量相等的节点,并且softmax激活函数将给出在可能的类上的概率分布。因此,神经网络的输出将会把输入向量属于每一个可能类的概率输出给我们,我们选择概率最高的类作为神经网络的预测。


当把softmax函数作为输出层的激活函数时,通常使用交叉熵损失作为损失函数。交叉熵损失衡量两个概率分布的相似程度。我们可以将输入值x的真实标签表示为一个概率分布:其中真实类标签的概率为1,其他类标签的概率为0。标签的这种表示也被称为一个热编码。然后我们用交叉熵来衡量网络的预测概率分布与真实概率分布的接近程度。


image


其中y是真标签的一个热编码,y hat是预测的概率分布,yi,yi hat是这些向量的元素。


如果预测的概率分布接近真实标签的一个热编码,那么损失函数的值将接近于0。否则如果它们相差很大,损失函数的值可能会无限大。


均方误差损失函数


softmax激活和交叉熵损失主要用于分类任务,而神经网络只需在最后一层使用适当的损失函数和激活函数就可以很容易地适应回归任务。例如,如果我们没有类标签作为依据,我们有一个我们想要近似的数字列表,我们可以使用均方误差(简称MSE)损失函数。通常,当我们使用MSE损失函数时,我们在最后一层使用身份激活(即f(x)=x)。


image


综上所述,神经网络的学习过程只不过是一个优化问题:我们要找到使损失函数最小化的参数。但这不是一件容易的事,有很多关于优化技术的书。而且,除了优化之外,对于给定的任务选择哪种神经网络结构也会出现问题。


我希望这篇文章对你有帮助,并十分感谢你的阅读。





本文作者:deephub



目录
相关文章
|
1月前
|
编解码 安全 Linux
网络空间安全之一个WH的超前沿全栈技术深入学习之路(10-2):保姆级别教会你如何搭建白帽黑客渗透测试系统环境Kali——Liinux-Debian:就怕你学成黑客啦!)作者——LJS
保姆级别教会你如何搭建白帽黑客渗透测试系统环境Kali以及常见的报错及对应解决方案、常用Kali功能简便化以及详解如何具体实现
|
1月前
|
安全 网络协议 算法
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-1):主动信息收集之ping、Nmap 就怕你学成黑客啦!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-1):主动信息收集之ping、Nmap 就怕你学成黑客啦!
|
3月前
|
监控 网络协议 Linux
网络学习
网络学习
150 68
|
1月前
|
网络协议 安全 NoSQL
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-2):scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练、就怕你学成黑客啦!
scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练等具体操作详解步骤;精典图示举例说明、注意点及常见报错问题所对应的解决方法IKUN和I原们你这要是学不会我直接退出江湖;好吧!!!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(8-2):scapy 定制 ARP 协议 、使用 nmap 进行僵尸扫描-实战演练、就怕你学成黑客啦!
|
1月前
|
网络协议 安全 算法
网络空间安全之一个WH的超前沿全栈技术深入学习之路(9):WireShark 简介和抓包原理及实战过程一条龙全线分析——就怕你学成黑客啦!
实战:WireShark 抓包及快速定位数据包技巧、使用 WireShark 对常用协议抓包并分析原理 、WireShark 抓包解决服务器被黑上不了网等具体操作详解步骤;精典图示举例说明、注意点及常见报错问题所对应的解决方法IKUN和I原们你这要是学不会我直接退出江湖;好吧!!!
网络空间安全之一个WH的超前沿全栈技术深入学习之路(9):WireShark 简介和抓包原理及实战过程一条龙全线分析——就怕你学成黑客啦!
|
2月前
|
存储 安全 网络安全
浅谈网络安全的认识与学习规划
浅谈网络安全的认识与学习规划
38 6
|
1月前
|
人工智能 安全 Linux
网络空间安全之一个WH的超前沿全栈技术深入学习之路(4-2):渗透测试行业术语扫盲完结:就怕你学成黑客啦!)作者——LJS
网络空间安全之一个WH的超前沿全栈技术深入学习之路(4-2):渗透测试行业术语扫盲完结:就怕你学成黑客啦!)作者——LJS
|
1月前
|
安全 大数据 Linux
网络空间安全之一个WH的超前沿全栈技术深入学习之路(3-2):渗透测试行业术语扫盲)作者——LJS
网络空间安全之一个WH的超前沿全栈技术深入学习之路(3-2):渗透测试行业术语扫盲)作者——LJS
|
1月前
|
SQL 安全 网络协议
网络空间安全之一个WH的超前沿全栈技术深入学习之路(1-2):渗透测试行业术语扫盲)作者——LJS
网络空间安全之一个WH的超前沿全栈技术深入学习之路(1-2):渗透测试行业术语扫盲)作者——LJS