LSTM(长短期记忆网络)原理介绍

简介: LSTM算法是一种重要的目前使用最多的时间序列算法,是一种特殊的RNN(Recurrent Neural Network,循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

LSTM算法(Long Short Term Memory, 长短期记忆网络 )


1.概念介绍

LSTM算法是一种重要的目前使用最多的时间序列算法,是一种特殊的RNN(Recurrent Neural Network,循环神经网络),能够学习长期的依赖关系。主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。


2.网络结构

所有RNN都具有神经网络的重复模块链的形式。 在标准的RNN中,该重复模块将具有非常简单的结构,例如单个tanh层。


标准的RNN网络如下图所示

image.png


LSTM也具有神经网络的重复模块链的形式。只是在RNN的基础上,每个重复模块增加了三个神经网络层,如下图所示:

image.png


 图中的绿色大框代表单元模块;黄色方框代表神经网络层;粉色圆圈代表逐点操作,例如矢量加法;箭头表示向量转换,从一个节点输出到另一个节点输入;合并的行表示串联,而分叉的行表示要复制的内容,并且副本将到达不同的位置。  


    和RNN不同的是: RNN中,就是个简单的线性求和的过程. 而LSTM可以通过“门”结构来去除或者增加“细胞状态”的信息,实现了对重要内容的保留和对不重要内容的去除. 通过Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允许任务变量通过”,1表示“运行所有变量通过 ”.


3.LSTM核心思想

image.png

首先CNN的主线就是这条顶部水平贯穿的线,也就是长期记忆C线(细胞状态),达到了序列学习的目的。而h可以看做是短期记忆,x代表事件信息,也就是输入。LSTM也是以这一条水平贯穿的C线为主线,在此基础上添加三个门,以保护控制单元状态。所以LSTM有删除或向单元状态添加信息的能力,都是由这门的结构来调节控制的。这个门(gate)是一种选择性的让信息通过的方式。它是由Sigmoid神经网络和矩阵逐点乘运算组成。


LSTM增加的三个神经网络层就代表LSTM的三个门(遗忘门、记忆门、输出门)



1.遗忘门

在我们 LSTM 中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为忘记门层完成。该门会读取ht−1和xt,输出一个在 0到 1之间的数值给每个在细胞状态Ct−1中的数字。1 表示“完全保留”,0 表示“完全舍弃”。

dae6948d7fed85fd104b7ad3d26e8bb0_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDE2MjEwNA==,size_16,color_FFFFFF,t_70#pic_center.jpg

其中ht−1表示的是上一个cell的输出,xt表示的是当前细胞的输入。σσ表示sigmod函数。


2.输入门

下一步是决定让多少新的信息加入到 cell 状态 中来。实现这个需要包括两个 步骤:首先,一个叫做“input gate layer ”的 sigmoid 层决定哪些信息需要更新;一个 tanh 层生成一个向量,也就是备选的用来更新的内容,C^t 。在下一步,我们把这两部分联合起来,对 cell 的状态进行一个更新。

3e1da1758fd09edfe605e7fc82ae86fa_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDE2MjEwNA==,size_16,color_FFFFFF,t_70#pic_center.jpg

现在是更新旧细胞状态的时间了,Ct−1更新为Ct。前面的步骤已经决定了将会做什么,我们现在就是实际去完成。


我们把旧状态与ft相乘,丢弃掉我们确定需要丢弃的信息。接着加上it∗C~t。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。


3.输出门

最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。首先,我们运行一个 sigmoid 层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。

9f327a3a140bdccfd885386bc3cf0a8b_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDE2MjEwNA==,size_16,color_FFFFFF,t_70#pic_center.jpg


4.LSTM公式详解

遗忘门(forget gate):它决定了上一时刻的单元状态 c_t-1 有多少保留到当前时刻 c_t


输入门(input gate):它决定了当前时刻网络的输入 x_t 有多少保存到单元状态 c_t


输出门(output gate):控制单元状态 c_t 有多少输出到 LSTM 的当前输出值 h_t


公式:

cc28db14d4e7fe56b3b1d54881ee277a_94bbef9aac064427dde78e26be743643.png

遗忘门的计算为:

972575c1b136f0c8f6870eaf3cf0f2eb_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzMxMjc4OTAz,size_16,color_FFFFFF,t_70.png

遗忘门的计算公式中:

W_f 是遗忘门的权重矩阵,[h_t-1, x_t] 表示把两个向量连接成一个更长的向量,b_f是遗忘门的偏置项,σ 是 sigmoid 函数。

输入门的计算:

f2000e68d2e2e26ed03c83cf31351830_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzMxMjc4OTAz,size_16,color_FFFFFF,t_70.png

根据上一次的输出和本次输入来计算当前输入的单元状态:

7089f8f73854f79fdd8d7f6562eac7c5_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzMxMjc4OTAz,size_16,color_FFFFFF,t_70.png

当前时刻的单元状态 c_t 的计算:由上一次的单元状态 c_t-1 按元素乘以遗忘门 f_t,再用当前输入的单元状态 c_t 按元素乘以输入门 i_t,再将两个积加和:这样,就可以把当前的记忆 c_t 和长期的记忆 c_t-1 组合在一起,形成了新的单元状态 c_t。由于遗忘门的控制,它可以保存很久很久之前的信息,由于输入门的控制,它又可以避免当前无关紧要的内容进入记忆。

f5bee781a1aced346932661a51366a60_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzMxMjc4OTAz,size_16,color_FFFFFF,t_70.png

输出门的计算:

0053d0afb32e89981befc5e1a16dba52_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzMxMjc4OTAz,size_16,color_FFFFFF,t_70.png

思想:

三个门控制对前一段信息、输入信息以及输出信息的记忆状态,进而保证网络可以更好地学习到长距离依赖关系。

遗忘门(记忆门):通过判断当前输入信息的重要程度决定对过去信息的保留度

输入门:通过判断当前输入信息的重要程度决定对输入信息的保留度

输出门:当前输出有多大程度取决于当前记忆单元

激活函数:

门:sigmoid,0-1分布概率,符合门控的定义。且当输入较大或者较小时,值会接近1或0,进而控制开关。

候选记忆:

分布在-1~1之间,与大多场景下0中心分布吻合

在输入为0有较大的梯度,使模型更快收敛

存在问题:

不可并行,只能从前到后-->attention


5.LSTM其他解释

LSTM隐层神经元结构:

ac7aad8b7c19f28747e240aeed835f43_20161012234754776.png

LSTM隐层神经元详细结构:

30c8717dcd23e8e5e19a55a60f88ee08_20161016103332355.png

bb67ca213fd512d9aef0f3b454177a3a_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzQ2MDA2NDY4,size_16,color_FFFFFF,t_70.png

一般由以下部分组成:


输入门:it =sigmoid(权重乘以当前的输入特征和上个时间步的短期记忆ht-1,再加上偏置)


遗忘门:ft =sigmoid(权重乘以当前的输入特征和上个时间步的短期记忆ht-1,再加上偏置)


输出门:ot=sigmoid(权重乘以当前的输入特征和上个时间步的短期记忆ht-1,再加上偏置)


候选态:候选态即被输入门限制是否可以进入长期记忆的内容。


记忆体: 就是筛选长期记忆的过程,通过tanh函数转化为输出。


细胞态就是过去所有知识的累计, 它由两部分组成:当前时间步之前的长期记忆乘以遗忘门,以及输入门乘以候选态。


这不难理解,遗忘门判断过去累积的知识哪些可以被遗忘,输入门判断哪些知识可以被存起来,输出门则筛选知识进行输出。


当有多层网络是时,前一层的输出ht就是下一层的输入。


6.LSTM的代码实现

keras.layers.LSTM(units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, implementation=1, return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False)

长短期记忆网络层(Long Short-Term Memory)

参数


units: 正整数,输出空间的维度。

activation: 要使用的激活函数 (详见 activations)。 如果传入 None,则不使用激活函数 (即 线性激活:a(x) = x)。

recurrent_activation: 用于循环时间步的激活函数 (详见 activations)。 默认:分段线性近似 sigmoid (hard_sigmoid)。 如果传入 None,则不使用激活函数 (即 线性激活:a(x) = x)。

use_bias: 布尔值,该层是否使用偏置向量。

kernel_initializer: kernel 权值矩阵的初始化器, 用于输入的线性转换 (详见 initializers)。

recurrent_initializer: recurrent_kernel 权值矩阵 的初始化器,用于循环层状态的线性转换 (详见 initializers)。

bias_initializer:偏置向量的初始化器 (详见initializers).

unit_forget_bias: 布尔值。 如果为 True,初始化时,将忘记门的偏置加 1。 将其设置为 True 同时还会强制 bias_initializer="zeros"。 这个建议来自 Jozefowicz et al.。

kernel_regularizer: 运用到 kernel 权值矩阵的正则化函数 (详见 regularizer)。

recurrent_regularizer: 运用到 recurrent_kernel 权值矩阵的正则化函数 (详见 regularizer)。

bias_regularizer: 运用到偏置向量的正则化函数 (详见 regularizer)。

activity_regularizer: 运用到层输出(它的激活值)的正则化函数 (详见 regularizer)。

kernel_constraint: 运用到 kernel 权值矩阵的约束函数 (详见 constraints)。

recurrent_constraint: 运用到 recurrent_kernel 权值矩阵的约束函数 (详见 constraints)。

bias_constraint: 运用到偏置向量的约束函数 (详见 constraints)。

dropout: 在 0 和 1 之间的浮点数。 单元的丢弃比例,用于输入的线性转换。

recurrent_dropout: 在 0 和 1 之间的浮点数。 单元的丢弃比例,用于循环层状态的线性转换。

implementation: 实现模式,1 或 2。 模式 1 将把它的操作结构化为更多的小的点积和加法操作, 而模式 2 将把它们分批到更少,更大的操作中。 这些模式在不同的硬件和不同的应用中具有不同的性能配置文件。

return_sequences: 布尔值。是返回输出序列中的最后一个输出,还是全部序列。

return_state: 布尔值。除了输出之外是否返回最后一个状态。

go_backwards: 布尔值 (默认 False)。 如果为 True,则向后处理输入序列并返回相反的序列。

stateful: 布尔值 (默认 False)。 如果为 True,则批次中索引 i 处的每个样品的最后状态 将用作下一批次中索引 i 样品的初始状态。

unroll: 布尔值 (默认 False)。 如果为 True,则网络将展开,否则将使用符号循环。 展开可以加速 RNN,但它往往会占用更多的内存。 展开只适用于短序列。


7.GRU

GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。

62c971c40a8353f5007c56ad524ab8e6_ab3e6c731e2d4ef0937b4fac304f1d5a.png

1、GRU概述

 GRU是LSTM网络的一种效果很好的变体,它较LSTM网络的结构更加简单,而且效果也很好,因此也是当前非常流形的一种网络。GRU既然是LSTM的变体,因此也是可以解决RNN网络中的长依赖问题。


 在LSTM中引入了三个门函数:输入门、遗忘门和输出门来控制输入值、记忆值和输出值。而在GRU模型中只有两个门:分别是更新门和重置门。具体结构如下图所示


GRU 结构

GRU分为重置门和更新门:

1d55118521869e181694a570e2c7465d_watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM2ODM1OTkx,size_16,color_FFFFFF,t_70.png

  图中的Zt和t分rt别表示更新门和重置门。更新门用于控制前一时刻的状态信息被带入到当前状态中的程度,更新门的值越大说明前一时刻的状态信息带入越多。重置门控制前一状态有多少信息被写入到当前的候选集 h~t 上,重置门越小,前一状态的信息被写入的越少。


2、GRU前向传播

85eaec0d22a794f8e248a59eef29e930_359a694c3762b6ddb4ed00dba36dd2e8.png

根据上面的GRU的模型图,我们来看看网络的前向传播公式:

image.png

其中[]表示两个向量相连,*表示矩阵的乘积。


3、GRU的训练过程

 从前向传播过程中的公式可以看出要学习的参数有Wr、Wz、Wh、Wo。其中前三个参数都是拼接的(因为后先的向量也是拼接的),所以在训练的过程中需要将他们分割出来:

image.png

 输出层的输入:

f9b04adcf7977c4b7d49db975dc2ece5_cad8190d66939b38664834b242190455.png

 输出层的输出:

5e81759de284eec59c41b9f614650b5f_4e8f5ae8c93cc46c00e81e766e84deca.png

在得到最终的输出后,就可以写出网络传递的损失,单个样本某时刻的损失为:

7bbbce5d62d7ee60396d5b3ee21faa5a_f15291b83a43d2a1c8659c4de9bfd621.png

 则单个样本的在所有时刻的损失为:

93db7c786dbd8094222af5b4dfa4d649_e7eda8a9af19fb75498ff6b146b70497.png

 采用后向误差传播算法来学习网络,所以先得求损失函数对各参数的偏导(总共有7个):  

image.png


 其中各中间参数为:

image.png

在算出了对各参数的偏导之后,就可以更新参数,依次迭代知道损失收敛。

 

概括来说,LSTM和CRU都是通过各种门函数来将重要特征保留下来,这样就保证了在long-term传播的时候也不会丢失。此外GRU相对于LSTM少了一个门函数,因此在参数的数量上也是要少于LSTM的,所以整体上GRU的训练速度要快于LSTM的。不过对于两个网络的好坏还是得看具体的应用场景。

目录
打赏
0
1
0
0
521
分享
相关文章
基于GRU网络的MQAM调制信号检测算法matlab仿真,对比LSTM
本研究基于MATLAB 2022a,使用GRU网络对QAM调制信号进行检测。QAM是一种高效调制技术,广泛应用于现代通信系统。传统方法在复杂环境下性能下降,而GRU通过门控机制有效提取时间序列特征,实现16QAM、32QAM、64QAM、128QAM的准确检测。仿真结果显示,GRU在低SNR下表现优异,且训练速度快,参数少。核心程序包括模型预测、误检率和漏检率计算,并绘制准确率图。
93 65
基于GRU网络的MQAM调制信号检测算法matlab仿真,对比LSTM
基于PSO粒子群优化的CNN-LSTM-SAM网络时间序列回归预测算法matlab仿真
本项目展示了基于PSO优化的CNN-LSTM-SAM网络时间序列预测算法。使用Matlab2022a开发,完整代码含中文注释及操作视频。算法结合卷积层提取局部特征、LSTM处理长期依赖、自注意力机制捕捉全局特征,通过粒子群优化提升预测精度。适用于金融市场、气象预报等领域,提供高效准确的预测结果。
基于GA遗传优化的CNN-LSTM-SAM网络时间序列回归预测算法matlab仿真
本项目使用MATLAB 2022a实现时间序列预测算法,完整程序无水印。核心代码包含详细中文注释和操作视频。算法基于CNN-LSTM-SAM网络,融合卷积层、LSTM层与自注意力机制,适用于金融市场、气象预报等领域。通过数据归一化、种群初始化、适应度计算及参数优化等步骤,有效处理非线性时间序列,输出精准预测结果。
基于WOA鲸鱼优化的CNN-LSTM-SAM网络时间序列回归预测算法matlab仿真
本内容介绍了一种基于CNN-LSTM-SAM网络与鲸鱼优化算法(WOA)的时间序列预测方法。算法运行于Matlab2022a,完整程序无水印并附带中文注释及操作视频。核心流程包括数据归一化、种群初始化、适应度计算及参数更新,最终输出最优网络参数完成预测。CNN层提取局部特征,LSTM层捕捉长期依赖关系,自注意力机制聚焦全局特性,全连接层整合特征输出结果,适用于复杂非线性时间序列预测任务。
基于CNN卷积神经网络的金融数据预测matlab仿真,对比BP,RBF,LSTM
本项目基于MATLAB2022A,利用CNN卷积神经网络对金融数据进行预测,并与BP、RBF和LSTM网络对比。核心程序通过处理历史价格数据,训练并测试各模型,展示预测结果及误差分析。CNN通过卷积层捕捉局部特征,BP网络学习非线性映射,RBF网络进行局部逼近,LSTM解决长序列预测中的梯度问题。实验结果表明各模型在金融数据预测中的表现差异。
212 10
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。
从理论到实践:如何使用长短期记忆网络(LSTM)改善自然语言处理任务
【10月更文挑战第7天】随着深度学习技术的发展,循环神经网络(RNNs)及其变体,特别是长短期记忆网络(LSTMs),已经成为处理序列数据的强大工具。在自然语言处理(NLP)领域,LSTM因其能够捕捉文本中的长期依赖关系而变得尤为重要。本文将介绍LSTM的基本原理,并通过具体的代码示例来展示如何在实际的NLP任务中应用LSTM。
533 4
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
105 17
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
70 10

热门文章

最新文章

AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等