LSTM(长短期记忆)网络的算法介绍及数学推导

简介: LSTM(长短期记忆)网络的算法介绍及数学推导
前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容不乏不准确的地方,希望批评指正,共同进步。

本文旨在说明LSTM正向传播及反向传播的算法及数学推导过程,其他内容CSDN上文章很多,不再赘述。因此在看本文前必须掌握以下两点基础知识:

①RNN的架构及算法:RNN作为LSTM的基础,是必须要先掌握的。

夹带私货,推荐自己的文章:基于Numpy构建RNN模块并进行实例应用(附代码)

②LSTM的架构:基于RNN引入上一时刻隐层输出的思想,LSTM又增加了细胞状态 C t C_t Ct的概念。 t t t时刻的输出除了要参考 t − 1 t-1 t1时刻隐层的输出 h t − 1 h_{t-1} ht1之外,还要参考 t − 1 t-1 t1时刻的细胞状态 C t − 1 C_{t-1} Ct1。为了计算细胞状态,引入忘记门、输出门、新记忆门、输出门几个路径。

推荐文章:如何从RNN起步,一步一步通俗理解LSTM 以及此篇文章中引用的文章,都值得好好看下。

基于colah的博客的LSTM结构图,稍微加工下得到下面的原理图:

一、LSTM正向传播算法

这块比较容易,只要严格按照上面原理图,正向传播的算法都容易得出。

1.隐藏层正向传播算法

t t t时刻各个门为:

  • 忘记门: image.png
  • 输入门: image.png
  • 新记忆门: image.png
  • 输出门: image.png

t t t时刻的细胞状态 C t C_t Ct为:

image.png

t t t时刻的隐层输出 h t h_t ht为:

image.png

σ \sigma σ为Sigmoid函数,⨀为矩阵的哈达马积。

2.输出层正向传播算法

t t t时刻的最终输出为:

image.png

二、LSTM的反向传播算法

重点,也是LTSM算法的难点来了。


※关于反向传播,始终要牢记其目的是:求解损失函数E关于各个权重的偏导。


既然有了正向传播的算法公式,那么反向传播就变成了一个求偏导的纯粹数学问题。下面以对忘记门的权重 w f w_f wf求偏导为例,讲解这个过程。

损失函数E对权重w f 的偏导为:

这里的E根据损失函数的选择而不同,例如交叉熵损失函数,即为:

image.png

可见这个偏导由3个部分组成:

1. 损失函数E对细胞状态 C t的偏导

首先我们要明白损失函数E是一个关于 image.png 的函数,即:

image.png

根据正向传播公式, h t 是 C t 的函数, C t是  Ct1的函数,即:

image.png

这样,求损失函数E对细胞状态 C t C_t Ct的偏导就成了高等数学中对复合函数求偏导的问题了。

代入上式,最终得出:

首先计算t = n时刻细胞状态的偏导,即E对C n 的偏导:

image.png

反向传播,再求E对C n−1的偏导:

image.png

反向传播,再求E对Cn−2 的偏导:

image.png

以此类推,容易得出t时刻E对C t的偏导:

image.png

根据正向传播公式,可以得出:

image.png

代入上式,最终得出:

实际上,上式的乘法“ · ”对于矩阵而言,都是哈达马积“⨀”。为了方便理解,均以单个变量而非矩阵的形式为例说明求偏导的过程,下面也是如此,不再特殊说明。

2. 细胞状态 C t对忘记门 f t的偏导

根据正向传播公式容易得出:

3. 忘记门 f t f_t ft对权重 w f w_f wf的偏导

根据正向传播公式容易得出:

对于Sigmoid函数及上面tanh函数的求导过程略,如果不会CSDN上也能找到具体过程。

最终得出:

至此,LSTM的正向传播及反向传播的过程推导结束。


后面预告下用Python实现它。


填坑了,Python实现LSTM的链接:基于NumPy构建LSTM模块并进行实例应用(附代码)


相关文章
|
4月前
|
传感器 机器学习/深度学习 算法
【UASNs、AUV】无人机自主水下传感网络中遗传算法的路径规划问题研究(Matlab代码实现)
【UASNs、AUV】无人机自主水下传感网络中遗传算法的路径规划问题研究(Matlab代码实现)
156 0
|
3月前
|
存储 机器学习/深度学习 监控
网络管理监控软件的 C# 区间树性能阈值查询算法
针对网络管理监控软件的高效区间查询需求,本文提出基于区间树的优化方案。传统线性遍历效率低,10万条数据查询超800ms,难以满足实时性要求。区间树以平衡二叉搜索树结构,结合节点最大值剪枝策略,将查询复杂度从O(N)降至O(logN+K),显著提升性能。通过C#实现,支持按指标类型分组建树、增量插入与多维度联合查询,在10万记录下查询耗时仅约2.8ms,内存占用降低35%。测试表明,该方案有效解决高负载场景下的响应延迟问题,助力管理员快速定位异常设备,提升运维效率与系统稳定性。
260 4
|
3月前
|
机器学习/深度学习 算法
采用蚁群算法对BP神经网络进行优化
使用蚁群算法来优化BP神经网络的权重和偏置,克服传统BP算法容易陷入局部极小值、收敛速度慢、对初始权重敏感等问题。
393 5
|
4月前
|
存储 算法 安全
即时通讯安全篇(三):一文读懂常用加解密算法与网络通讯安全
作为开发者,也会经常遇到用户对数据安全的需求,当我们碰到了这些需求后如何解决,如何何种方式保证数据安全,哪种方式最有效,这些问题经常困惑着我们。52im社区本次着重整理了常见的通讯安全问题和加解密算法知识与即时通讯/IM开发同行们一起分享和学习。
424 9
|
4月前
|
机器学习/深度学习 传感器 算法
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
335 2
|
3月前
|
机器学习/深度学习 人工智能 算法
【基于TTNRBO优化DBN回归预测】基于瞬态三角牛顿-拉夫逊优化算法(TTNRBO)优化深度信念网络(DBN)数据回归预测研究(Matlab代码实现)
【基于TTNRBO优化DBN回归预测】基于瞬态三角牛顿-拉夫逊优化算法(TTNRBO)优化深度信念网络(DBN)数据回归预测研究(Matlab代码实现)
198 0
|
4月前
|
算法 数据挖掘 区块链
基于遗传算法的多式联运车辆路径网络优优化研究(Matlab代码实现)
基于遗传算法的多式联运车辆路径网络优优化研究(Matlab代码实现)
155 2
|
5月前
|
机器学习/深度学习 算法 安全
【PSO-LSTM】基于PSO优化LSTM网络的电力负荷预测(Python代码实现)
【PSO-LSTM】基于PSO优化LSTM网络的电力负荷预测(Python代码实现)
308 0
|
7月前
|
机器学习/深度学习 算法 数据挖掘
基于WOA鲸鱼优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
本项目基于MATLAB 2022a/2024b实现,采用WOA优化的BiLSTM算法进行序列预测。核心代码包含完整中文注释与操作视频,展示从参数优化到模型训练、预测的全流程。BiLSTM通过前向与后向LSTM结合,有效捕捉序列前后文信息,解决传统RNN梯度消失问题。WOA优化超参数(如学习率、隐藏层神经元数),提升模型性能,避免局部最优解。附有运行效果图预览,最终输出预测值与实际值对比,RMSE评估精度。适合研究时序数据分析与深度学习优化的开发者参考。
|
7月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于GA遗传优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
本内容包含基于BiLSTM与遗传算法(GA)的算法介绍及实现。算法通过MATLAB2022a/2024b运行,核心为优化BiLSTM超参数(如学习率、神经元数量),提升预测性能。LSTM解决传统RNN梯度问题,捕捉长期依赖;BiLSTM双向处理序列,融合前文后文信息,适合全局信息任务。附完整代码(含注释)、操作视频及无水印运行效果预览,适用于股票预测等场景,精度优于单向LSTM。

热门文章

最新文章