20分钟搞懂神经网络BP算法

简介: 通过一个具体的例子来说明神经网络中的BP算法,使大家能够很直观地感受BP算法的过程,对BP算法加深了解和认识。

在学习深度学习过程中,无意间发现一篇介绍BP算法的文章,感觉非常直观,容易理解。这篇文章的最大亮点是:不像其他介绍BP算法的文章,用一堆数据符号和公式来推导。文中通过使用一条具体的样本数据,为我们展示了模型训练中的参数迭代计算过程,为我们理解BP算法提供了很直观的理解视角;其次,作者也给出了使用python来实现BP的算法。只要你了解过传统神经网络结构以及大学微积分的知识,都可以毫不费力的在20分钟内完全理解BP算法。这里整理出来,供大家学习参考。要看原文的同学,直接跳到文末点击原文链接。


在开始之前,提醒下大家,注意公式中的下标,结合网络结构帮忙我们理解算法推导计算过程和细节。

网络结构和样本数据

跟所有训练神经网络或深度学习模型的流程一样,首先要先确定网络结构。这里为了介绍上的方便,以2个输入节点,2个隐藏节点,2个输出节点的网络(包括bias项)为例,展开对BP算法的介绍。如下图所示:
nn_BP_1.png
下面作者开始引入网络中参数的初始权重,以及一个训练样本,如下图中节点和边上的数值:
nn_BP_2.png

BP算法的目标就是优化神经网络的权重使得学习到的模型能够将输入值正确地映射到实际的输出值(也就是,希望模型能够模型真实数据产生的机制。在统计学中就是,我们要学习一个统计模型(统计分布函数),使得真实数据分布与统计模型产生的样本分布尽可能一致)。

如上图所示,下面的参数求解迭代过程,就是为了使得输入样本是0.05和0.10时(一个2维的样本数据),神经网络的输出与0.01和0.99接近。


前向传播过程

前向传播很简单,就是在已经给定的数据和参数下,按照网络结构来逐层传递数据,最后在输出层计算网络的输出与样本真实的目标值误差,这个误差就是模型的目标函数。

具体到这个case中,在给定模型输出权重和bias的条件下,我们需要把样本数据(0.05,0.10)通过图二中的网络逐步向后传递,看网络的输出与实际的输出的差异。

下面推导计算过程中,网络中使用的激活函数是logistic函数(或sigmoid函数):

σ(x)=11+ex

首先来计算隐藏节点 h1 的输入值:

neth1=0.150.05+0.20.1+0.351=0.3775

得到h1的输入值后,我们使用激活函数(logistic函数)来将输入值转化为为h1的输出值:

outh1=11+eneth1=11+e0.3775=0.593269992

按同样的方式,我们可以计算h2的输出值:

outh2=0.596884378

类似于计算h1h2的过程,我们可以计算输出层节点o1o2的值。下面是o1的输出值计算过程:

neto1=w5outh1+w6outh2+b21

neto1=0.40.593269992+0.450.596884378+0.61=1.105905967

outo1=11+eneto1=11+e1.105905967=0.75136507

同样的方式,o2的输出值为:

outo2=0.772928465

计算模型总误差

得到了网络的输出值后,就可以计算输出值与真实值之间的误差。这里我们使用平方误差来计算模型总误差:

Etotal=12(targetoutput)2

上式中的target就是样本目标值,或真实值。12只是为了计算上的整洁,对实际参数的估计没有影响 。(The 12 is included so that exponent is cancelled when we differentiate later on. The result is eventually multiplied by a learning rate anyway so it doesn’t matter that we introduce a constant here。)

对于输出节点o1的误差为:

Eo1=12(targeto1outo1)2=12(0.010.75136507)2=0.274811083

类似的计算方法,o2的误差为:

Eo2=0.023560026

最后,通过这个前向传递后,这个神经网络的总误差为:

Etotal=Eo1+Eo2=0.274811083+0.023560026=0.298371109

后向传播过程

后向传播过程就是迭代网络参数的过程,通过误差的后向传播得到新的模型参数,基于这个新的模型参数,再经过下一次的前向传播,模型误差会减小,从而使得模型输出值与实际值越接近。

输出层(output layer)

我们先来看了离误差最近的输出层中涉及的参数。以w5为例,我们想知道w5的改变对整体误差的影响,那么我们自然会想到对模型总误差求关于w5的偏导数Etotalw5。这个值也称为误差在w5方向上的梯度。
应用求导的链式法则,我们可以对偏导数Etotalw5进行如下的改写:

Etotalw5=Etotalouto1outo1neto1neto1w5

这个公式可以对应到具体的相应网络结构:
nn_BP_3.png
为了得到Etotalw5的值,我们需要计算上式中的每个因子的值。首先我们来计算误差关于o1输出值的偏导数,计算方式如下:

Etotal=12(targeto1outo1)2+12(targeto2outo2)2

Etotalouto1=212(targeto1outo1)211+0

Etotalouto1=(targeto1outo1)=(0.010.75136507)=0.74136507

下一步就是要计算outo1neto1,这个值的含义如上图中所示,就是激活函数对自变量的求导:

outo1=11+eneto1

outo1neto1=outo1(1outo1)=0.75136507(10.75136507)=0.186815602

logistic函数对自变量求导,可参考:https://en.wikipedia.org/wiki/Logistic_function#Derivative

现在还需要计算最后一个引子的值neto1w5,这里neto1就是激活函数的输入值:

neto1=w5outh1+w6outh2+b21

那么对w5求偏导就很直接了:

neto1w5=1outh1w(11)5+0+0=outh1=0.593269992

得到三个因子后,我们就得到了总误差关于w5的偏导数:

Etotalw5=Etotalouto1outo1neto1neto1w5

Etotalw5=0.741365070.1868156020.593269992=0.082167041

为了减小误差,我们就可以类似于梯度下降的方式,来更新w5的值:

w+5=w5ηEtotalw5=0.40.50.082167041=0.35891648

上式中的η为学习率(learning rate),这里设为0.5. 在实际训练模型中,需要根据实际样本数据和网络结构来进行调整。

以类似的方式,我们同样可以得到 w6,w7,w8的更新值:
w+6=0.408666186
w+7=0.511301270
w+8=0.561370121
至此,我们得到了输出层节点中的参数更新值。下面我们以同样的方式来更新隐藏层节点中的参数值。

隐藏层 (hidden layer)

在隐藏层中,同样地,我们对总误差求关于w1,w2,w3,w4的偏导数,来获得更新值。首先还是应用求导的链式法则对总误差关于w1,w2,w3,w4的偏导数,以w1为例,分解如下:

Etotalw1=Etotalouth1outh1neth1neth1w1

用网络结构图来表示如下,从图中可以更直观地理解这种分解的物理意义:
nn_BP_4.png

与输出层中对权重求偏导数不同的一个地方是,由于每个隐藏层节点都会影响所有的输出层节点,在求总误差对隐藏层的输出变量求偏导数时,需要对组成总误差的每个输出层节点误差进行分别求偏导数。具体如下:

Etotalouth1=Eo1outh1+Eo2outh1

我们先来求第一项Eo1outh1的值,过程如下:

Eo1outh1=Eo1neto1neto1outh1

Eo1neto1=Eo1outo1outo1neto1=0.741365070.186815602=0.138498562

这一步可以利用输出层的计算结果。

neto1=w5outh1+w6outh2+b21

neto1outh1=w5=0.40

因此,

Eo1outh1=Eo1neto1neto1outh1=0.1384985620.40=0.055399425

类似地,我们可以求得Eo2outh1的值:

Eo2outh1=0.019049119

那么我们就可以得到Etotalouth1的值:

Etotalouth1=Eo1outh1+Eo2outh1=0.055399425+0.019049119=0.036350306

我们还需要计算outh1neth1neth1w就可以得到Etotalw1的值了。这两个值的计算方法跟输出层的完全类似,过程如下:

outh1=11+eneth1

outh1neth1=outh1(1outh1)=0.59326999(10.59326999)=0.241300709

neth1=w1i1+w3i2+b11

neth1w1=i1=0.05

最后把三个因子相乘就是我们需要计算的值:

Etotalw1=Etotalouth1outh1neth1neth1w1

Etotalw1=0.0363503060.2413007090.05=0.000438568

w1的更新值为:

w+1=w1ηEtotalw1=0.150.50.000438568=0.149780716

同样的方式,w2,w3,w4的更新值为:

w+2=0.19956143

w+3=0.24975114

w+4=0.29950229

从上面更新隐藏层节点参数的过程中,我们可以看到,这里的更新并没有用到输出层节点更新后的参数的值,还是基于老的参数来进行的。这个不能搞混。

上面的计算中,并没有对bias项的权重进行更新,更新方式其实也很简单。可以类似操作。

至此,我们已经完成了一轮BP的迭代。经过这轮迭代后,基于新的参数,再走一遍前向传播来计算新的模型误差,这时已经下降到0.291027924,相比第一次的误差 0.298371109貌似没减少太多。但是我们重复这个过程10000次后,误差已经下降到0.0000351085,下降了很多。这时模型的输出结果为0.015912196和0.984065734,跟实际的结果0.01和0.99已经很接近了。


这里只是一个样本数据,那么我们有很多样本呢?很多样本的情况下的计算跟这一个样本数据相比,有什么不同呢?自己比划比划吧~


原文链接地址:https://mattmazur.com/2015/03/17/a-step-by-step-backpropagation-example/
pyhont代码:https://github.com/mattm/simple-neural-network/blob/master/neural-network.py

附:神经网络入门材料:http://neuralnetworksanddeeplearning.com/index.html 可以整体上了解神经网络结构以及训练过程中存在的问题。虽然是英文,但使用的词汇都比较简单,看起来很顺畅

目录
打赏
0
0
0
1
1
分享
相关文章
解读 C++ 助力的局域网监控电脑网络连接算法
本文探讨了使用C++语言实现局域网监控电脑中网络连接监控的算法。通过将局域网的拓扑结构建模为图(Graph)数据结构,每台电脑作为顶点,网络连接作为边,可高效管理与监控动态变化的网络连接。文章展示了基于深度优先搜索(DFS)的连通性检测算法,用于判断两节点间是否存在路径,助力故障排查与流量优化。C++的高效性能结合图算法,为保障网络秩序与信息安全提供了坚实基础,未来可进一步优化以应对无线网络等新挑战。
基于 PHP 语言深度优先搜索算法的局域网网络监控软件研究
在当下数字化时代,局域网作为企业与机构内部信息交互的核心载体,其稳定性与安全性备受关注。局域网网络监控软件随之兴起,成为保障网络正常运转的关键工具。此类软件的高效运行依托于多种数据结构与算法,本文将聚焦深度优先搜索(DFS)算法,探究其在局域网网络监控软件中的应用,并借助 PHP 语言代码示例予以详细阐释。
23 1
企业用网络监控软件中的 Node.js 深度优先搜索算法剖析
在数字化办公盛行的当下,企业对网络监控的需求呈显著增长态势。企业级网络监控软件作为维护网络安全、提高办公效率的关键工具,其重要性不言而喻。此类软件需要高效处理复杂的网络拓扑结构与海量网络数据,而算法与数据结构则构成了其核心支撑。本文将深入剖析深度优先搜索(DFS)算法在企业级网络监控软件中的应用,并通过 Node.js 代码示例进行详细阐释。
19 2
基于MobileNet深度学习网络的活体人脸识别检测算法matlab仿真
本内容主要介绍一种基于MobileNet深度学习网络的活体人脸识别检测技术及MQAM调制类型识别方法。完整程序运行效果无水印,需使用Matlab2022a版本。核心代码包含详细中文注释与操作视频。理论概述中提到,传统人脸识别易受非活体攻击影响,而MobileNet通过轻量化的深度可分离卷积结构,在保证准确性的同时提升检测效率。活体人脸与非活体在纹理和光照上存在显著差异,MobileNet可有效提取人脸高级特征,为无线通信领域提供先进的调制类型识别方案。
基于模糊神经网络的金融序列预测算法matlab仿真
本程序为基于模糊神经网络的金融序列预测算法MATLAB仿真,适用于非线性、不确定性金融数据预测。通过MAD、RSI、KD等指标实现序列预测与收益分析,运行环境为MATLAB2022A,完整程序无水印。算法结合模糊逻辑与神经网络技术,包含输入层、模糊化层、规则层等结构,可有效处理金融市场中的复杂关系,助力投资者制定交易策略。
基于PSO粒子群优化的CNN-LSTM-SAM网络时间序列回归预测算法matlab仿真
本项目展示了基于PSO优化的CNN-LSTM-SAM网络时间序列预测算法。使用Matlab2022a开发,完整代码含中文注释及操作视频。算法结合卷积层提取局部特征、LSTM处理长期依赖、自注意力机制捕捉全局特征,通过粒子群优化提升预测精度。适用于金融市场、气象预报等领域,提供高效准确的预测结果。
JS数组操作方法全景图,全网最全构建完整知识网络!js数组操作方法全集(实现筛选转换、随机排序洗牌算法、复杂数据处理统计等情景详解,附大量源码和易错点解析)
这些方法提供了对数组的全面操作,包括搜索、遍历、转换和聚合等。通过分为原地操作方法、非原地操作方法和其他方法便于您理解和记忆,并熟悉他们各自的使用方法与使用范围。详细的案例与进阶使用,方便您理解数组操作的底层原理。链式调用的几个案例,让您玩转数组操作。 只有锻炼思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~
|
22天前
|
公司电脑网络监控场景下 Python 广度优先搜索算法的深度剖析
在数字化办公时代,公司电脑网络监控至关重要。广度优先搜索(BFS)算法在构建网络拓扑、检测安全威胁和优化资源分配方面发挥重要作用。通过Python代码示例展示其应用流程,助力企业提升网络安全与效率。未来,更多创新算法将融入该领域,保障企业数字化发展。
42 10
基于GA遗传优化TCN-LSTM时间卷积神经网络时间序列预测算法matlab仿真
本项目基于MATLAB 2022a实现了一种结合遗传算法(GA)优化的时间卷积神经网络(TCN)时间序列预测算法。通过GA全局搜索能力优化TCN超参数(如卷积核大小、层数等),显著提升模型性能,优于传统GA遗传优化TCN方法。项目提供完整代码(含详细中文注释)及操作视频,运行后无水印效果预览。 核心内容包括:1) 时间序列预测理论概述;2) TCN结构(因果卷积层与残差连接);3) GA优化流程(染色体编码、适应度评估等)。最终模型在金融、气象等领域具备广泛应用价值,可实现更精准可靠的预测结果。
|
25天前
|
基于 C# 网络套接字算法的局域网实时监控技术探究
在数字化办公与网络安全需求增长的背景下,局域网实时监控成为企业管理和安全防护的关键。本文介绍C#网络套接字算法在局域网实时监控中的应用,涵盖套接字创建、绑定监听、连接建立和数据传输等操作,并通过代码示例展示其实现方式。服务端和客户端通过套接字进行屏幕截图等数据的实时传输,保障网络稳定与信息安全。同时,文章探讨了算法的优缺点及优化方向,如异步编程、数据压缩与缓存、错误处理与重传机制,以提升系统性能。
39 2

热门文章

最新文章