LSTM(长短期记忆)网络的算法介绍及数学推导

简介: LSTM(长短期记忆)网络的算法介绍及数学推导
前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容不乏不准确的地方,希望批评指正,共同进步。

本文旨在说明LSTM正向传播及反向传播的算法及数学推导过程,其他内容CSDN上文章很多,不再赘述。因此在看本文前必须掌握以下两点基础知识:

①RNN的架构及算法:RNN作为LSTM的基础,是必须要先掌握的。

夹带私货,推荐自己的文章:基于Numpy构建RNN模块并进行实例应用(附代码)

②LSTM的架构:基于RNN引入上一时刻隐层输出的思想,LSTM又增加了细胞状态 C t C_t Ct的概念。 t t t时刻的输出除了要参考 t − 1 t-1 t1时刻隐层的输出 h t − 1 h_{t-1} ht1之外,还要参考 t − 1 t-1 t1时刻的细胞状态 C t − 1 C_{t-1} Ct1。为了计算细胞状态,引入忘记门、输出门、新记忆门、输出门几个路径。

推荐文章:如何从RNN起步,一步一步通俗理解LSTM 以及此篇文章中引用的文章,都值得好好看下。

基于colah的博客的LSTM结构图,稍微加工下得到下面的原理图:

一、LSTM正向传播算法

这块比较容易,只要严格按照上面原理图,正向传播的算法都容易得出。

1.隐藏层正向传播算法

t t t时刻各个门为:

  • 忘记门: image.png
  • 输入门: image.png
  • 新记忆门: image.png
  • 输出门: image.png

t t t时刻的细胞状态 C t C_t Ct为:

image.png

t t t时刻的隐层输出 h t h_t ht为:

image.png

σ \sigma σ为Sigmoid函数,⨀为矩阵的哈达马积。

2.输出层正向传播算法

t t t时刻的最终输出为:

image.png

二、LSTM的反向传播算法

重点,也是LTSM算法的难点来了。


※关于反向传播,始终要牢记其目的是:求解损失函数E关于各个权重的偏导。


既然有了正向传播的算法公式,那么反向传播就变成了一个求偏导的纯粹数学问题。下面以对忘记门的权重 w f w_f wf求偏导为例,讲解这个过程。

损失函数E对权重w f 的偏导为:

这里的E根据损失函数的选择而不同,例如交叉熵损失函数,即为:

image.png

可见这个偏导由3个部分组成:

1. 损失函数E对细胞状态 C t的偏导

首先我们要明白损失函数E是一个关于 image.png 的函数,即:

image.png

根据正向传播公式, h t 是 C t 的函数, C t是  Ct1的函数,即:

image.png

这样,求损失函数E对细胞状态 C t C_t Ct的偏导就成了高等数学中对复合函数求偏导的问题了。

代入上式,最终得出:

首先计算t = n时刻细胞状态的偏导,即E对C n 的偏导:

image.png

反向传播,再求E对C n−1的偏导:

image.png

反向传播,再求E对Cn−2 的偏导:

image.png

以此类推,容易得出t时刻E对C t的偏导:

image.png

根据正向传播公式,可以得出:

image.png

代入上式,最终得出:

实际上,上式的乘法“ · ”对于矩阵而言,都是哈达马积“⨀”。为了方便理解,均以单个变量而非矩阵的形式为例说明求偏导的过程,下面也是如此,不再特殊说明。

2. 细胞状态 C t对忘记门 f t的偏导

根据正向传播公式容易得出:

3. 忘记门 f t f_t ft对权重 w f w_f wf的偏导

根据正向传播公式容易得出:

对于Sigmoid函数及上面tanh函数的求导过程略,如果不会CSDN上也能找到具体过程。

最终得出:

至此,LSTM的正向传播及反向传播的过程推导结束。


后面预告下用Python实现它。


填坑了,Python实现LSTM的链接:基于NumPy构建LSTM模块并进行实例应用(附代码)


相关文章
|
23天前
|
机器学习/深度学习 资源调度 算法
图卷积网络入门:数学基础与架构设计
本文系统地阐述了图卷积网络的架构原理。通过简化数学表述并聚焦于矩阵运算的核心概念,详细解析了GCN的工作机制。
59 3
图卷积网络入门:数学基础与架构设计
|
4月前
|
机器学习/深度学习 API 异构计算
7.1.3.2、使用飞桨实现基于LSTM的情感分析模型的网络定义
该文章详细介绍了如何使用飞桨框架实现基于LSTM的情感分析模型,包括网络定义、模型训练、评估和预测的完整流程,并提供了相应的代码实现。
|
2月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。
|
2月前
|
机器学习/深度学习 数据可视化
KAN干翻MLP,开创神经网络新范式!一个数十年前数学定理,竟被MIT华人学者复活了
【10月更文挑战第12天】MIT华人学者提出了一种基于Kolmogorov-Arnold表示定理的新型神经网络——KAN。与传统MLP不同,KAN将可学习的激活函数放在权重上,使其在表达能力、准确性、可解释性和收敛速度方面表现出显著优势,尤其在处理高维数据时效果更佳。然而,KAN的复杂性也可能带来部署和维护的挑战。论文地址:https://arxiv.org/pdf/2404.19756
58 1
|
2月前
|
机器学习/深度学习 存储 自然语言处理
从理论到实践:如何使用长短期记忆网络(LSTM)改善自然语言处理任务
【10月更文挑战第7天】随着深度学习技术的发展,循环神经网络(RNNs)及其变体,特别是长短期记忆网络(LSTMs),已经成为处理序列数据的强大工具。在自然语言处理(NLP)领域,LSTM因其能够捕捉文本中的长期依赖关系而变得尤为重要。本文将介绍LSTM的基本原理,并通过具体的代码示例来展示如何在实际的NLP任务中应用LSTM。
181 4
|
6月前
|
机器学习/深度学习 存储 自然语言处理
程序与技术分享:DeepMemoryNetwork深度记忆网络
程序与技术分享:DeepMemoryNetwork深度记忆网络
|
2月前
|
机器学习/深度学习 人工智能 Rust
MindSpore QuickStart——LSTM算法实践学习
MindSpore QuickStart——LSTM算法实践学习
52 2
|
4月前
|
机器学习/深度学习 数据采集 数据可视化
【优秀python系统毕设】基于Python flask的气象数据可视化系统设计与实现,有LSTM算法预测气温
本文介绍了一个基于Python Flask框架开发的气象数据可视化系统,该系统集成了数据获取、处理、存储、LSTM算法气温预测以及多种数据可视化功能,旨在提高气象数据的利用价值并推动气象领域的发展。
259 1
|
4月前
|
机器学习/深度学习
【机器学习】面试题:LSTM长短期记忆网络的理解?LSTM是怎么解决梯度消失的问题的?还有哪些其它的解决梯度消失或梯度爆炸的方法?
长短时记忆网络(LSTM)的基本概念、解决梯度消失问题的机制,以及介绍了包括梯度裁剪、改变激活函数、残差结构和Batch Normalization在内的其他方法来解决梯度消失或梯度爆炸问题。
191 2
|
4月前
|
机器学习/深度学习 数据采集 存储
基于Python+flask+echarts的气象数据采集与分析系统,可实现lstm算法进行预测
本文介绍了一个基于Python、Flask和Echarts的气象数据采集与分析系统,该系统集成了LSTM算法进行数据预测,并提供了实时数据监测、历史数据查询、数据可视化以及用户权限管理等功能。
132 0