NLP学习笔记(二) LSTM基本介绍

简介: NLP学习笔记(二) LSTM基本介绍

前言


大家好,我是半虹,这篇文章来讲长短期记忆网络 (Long Short-Term Memory, LSTM)

文章行文思路如下:

  1. 首先通过循环神经网络引出为啥需要长短期记忆网络
  2. 然后介绍长短期记忆网络的核心思想与运作方式
  3. 最后通过简短的代码深入理解长短期记忆网络的运作方式



正文


长短期记忆网络可以看作是循环神经网络的改进版本,想要理解长短期记忆网络,首先要了解循环神经网络


由于我们之前已详细介绍过循环神经网络,所以这里我们只会做一个简单的回顾,想看详细的说明请戳这里



对比前馈神经网络,循环神经网络通过增加隐状态实现对隐藏层信息的传递,以此达到记住历史输入的目的


网络在每个时间步里读取上一隐藏层输出作为当前隐藏层输入,并保存当前隐藏层输出作为下一隐藏层输入


其结构简图如下:

啊啊啊.png其中 X是输入 ,H 是隐藏层的输出,图中的每个矩形都表示同一个循环神经网络隐藏层


下面我们把隐藏层中的细节也画出来,方便后面与长短期记忆网络来对比


森斯.png

其中 X 是输入 ,H 是隐藏层的输出,图中的灰色矩形同样代表隐藏层,σ \sigmaσ 表示一个带激活函数的线性层

image.png

image.png

理论上,上述介绍的循环神经网络能处理任意长的序列,但实际上却并非如此


在实际应用循环神经网络处理长序列时通常会出现梯度爆炸或梯度消失的情况,导致网络难以捕捉长期依赖


这是为什么呢?通过简单分析一下梯度计算公式就能发现端倪


为了阐述方便,我们暂且假定所有的参数都是一维的,用字母 θ \thetaθ 表示,对参数求导并按时间展开后如下所示


image.png

image.png

这说明了什么?这说明了对于当前输入,距其更远的输入的梯度更容易出现梯度爆炸或梯度消失

从而导致长距离的梯度反馈失效,这就是循环神经网络难以捕捉长期依赖的实际含义

image.png

总结一下,梯度反向传播时发生的异常,主要可以分为两种,一是梯度爆炸,二是梯度消失


梯度爆炸比较容易处理,一个简单但有效的做法是设置一个梯度阈值,当梯度超过这个阈值时直接截断


梯度消失更难处理一些,而现在流行的做法正是将循环神经网络替换成长短期记忆网络


注意,长短期记忆网络能缓解梯度消失的问题,但并不能缓解梯度爆炸的问题


上面我们从反向传播的角度解释了什么是梯度消失


如果我们从前向计算的角度来看,则梯度消失可以理解成隐状态对短期记忆敏感,对长期记忆作用有限


为了维持长期记忆,长短期记忆网络引入记忆元存放长期记忆,并通过门机制控制记忆元中的信息流动


从直觉上来说,先前重要的记忆会保留在记忆元,不重要的记忆会被过滤,以此来达到长期记忆的目的


这里有两个概念需要解释,一是记忆元,二是门机制,这两个就是长短期记忆网络的核心


先说记忆元,可以理解成另一种隐状态,都是用来记录附加信息的,简称为单元,英文为 Cell \text{Cell}Cell


再说门机制,这是用来控制记忆元中信息流动的机制,具体来说包括三个控制门:


输入门:控制是否将信息写入记忆元,英文为 Input Gate \text{Input Gate}Input Gate

遗忘门:控制是否从记忆元丢弃信息,英文为 Forget Gate \text{Forget Gate}Forget Gate

输出门:控制是否从记忆元读出信息,英文为 Output Gate \text{Output Gate}Output Gate


本质上来说,上述三个控制门都是由一个线性层加一个激活函数组成的,这里激活函数用的是 sigmoid \text{sigmoid}sigmoid


因为这样能将输出限制在零到一之间,以表示门的打开程度,控制信息流动的程度



相比循环神经网络只有一个传输状态,即隐状态,长短期记忆网络有两个传输状态,即隐状态和记忆元


二者的输入输出对比图如下:

经济.png





0817d35e6bb0345c63928c931c80d26.jpg


为了帮助大家进一步理解长短期记忆网络的工作方式,下面我们举一个例子来说,并给出关键代码

假设我们用长短期记忆网络对下面这个句子进行编码:我在画画

import torch
import torch.nn as nn
# 定义输入数据
# 对于输入句子我在画画,首先用独热编码得到其向量表示
x1 = torch.tensor([1, 0, 0]).float() # 我
x2 = torch.tensor([0, 1, 0]).float() # 在
x3 = torch.tensor([0, 0, 1]).float() # 画
x4 = torch.tensor([0, 0, 1]).float() # 画
h0 = torch.zeros(5) # 初始化隐状态
c0 = torch.zeros(5) # 初始化记忆元
# 定义模型参数
# 模型的输入是三维向量,这里定义模型的输出是五维向量
W_xi = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hi = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_i  = nn.Parameter(torch.randn(5)   , requires_grad = True)
W_xf = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hf = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_f  = nn.Parameter(torch.randn(5)   , requires_grad = True)
W_xo = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_ho = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_o  = nn.Parameter(torch.randn(5)   , requires_grad = True)
W_xc = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hc = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_c  = nn.Parameter(torch.randn(5)   , requires_grad = True)
# 前向传播
def forward(X, H, C):
    # 计算各种门机制
    I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i) # 输入门
    F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f) # 遗忘门
    O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o) # 输出门
    # 计算候选记忆元
    C_tilde = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
    # 计算当前记忆元
    C = F * C + I * C_tilde
    # 计算当前隐状态
    H = O * C.tanh()
    # 返回结果
    return H, C
h1, c1 = forward(x1, h0, c0)
h2, c2 = forward(x2, h1, c1)
h3, c3 = forward(x3, h2, c2)
h4, c4 = forward(x4, h3, c3)
# 结果输出
print(h3) # tensor([-0.0408,  0.1785,  0.0455,  0.3802,  0.0235])
print(h4) # tensor([-0.0560,  0.1269,  0.0346,  0.3426,  0.0118])

最后提醒大家一点,如果长短期记忆网络后有接其他网络,例如后面接一个线性层做单词预测


那么通常不会用记忆元的输出,而是用隐藏层的输出



至此本文结束,要点总结如下:


循环神经网络在处理长序列时很容易会出现梯度爆炸和梯度消失的情况,导致网络难以捕捉长期依赖

对于梯度爆炸,通常可以采用梯度裁剪解决,对于梯度消失,可以采用长短期记忆网络缓解


除了有隐状态,长短期记忆网络还增加记忆元存放长期记忆,并通过门机制控制记忆元中的信息流动


目录
相关文章
|
15天前
|
机器学习/深度学习 自然语言处理 数据可视化
数据代码分享|PYTHON用NLP自然语言处理LSTM神经网络TWITTER推特灾难文本数据、词云可视化
数据代码分享|PYTHON用NLP自然语言处理LSTM神经网络TWITTER推特灾难文本数据、词云可视化
|
5月前
|
机器学习/深度学习 自然语言处理 机器人
【Tensorflow+自然语言处理+LSTM】搭建智能聊天客服机器人实战(附源码、数据集和演示 超详细)
【Tensorflow+自然语言处理+LSTM】搭建智能聊天客服机器人实战(附源码、数据集和演示 超详细)
217 0
|
自然语言处理 算法
NLP学习笔记(十) 分词(下)
NLP学习笔记(十) 分词(下)
103 0
|
机器学习/深度学习 自然语言处理
NLP学习笔记(八) GPT简明介绍 下
NLP学习笔记(八) GPT简明介绍
131 0
|
自然语言处理
NLP学习笔记(八) GPT简明介绍 上
NLP学习笔记(八) GPT简明介绍
112 0
|
自然语言处理
NLP学习笔记(七) BERT简明介绍 下
NLP学习笔记(七) BERT简明介绍
151 0
NLP学习笔记(七) BERT简明介绍 下
|
机器学习/深度学习 自然语言处理
NLP学习笔记(七) BERT简明介绍 上
NLP学习笔记(七) BERT简明介绍
93 0
|
机器学习/深度学习 自然语言处理 计算机视觉
NLP学习笔记(六) Transformer简明介绍
NLP学习笔记(六) Transformer简明介绍
142 0
|
机器学习/深度学习 自然语言处理
NLP学习笔记(五) 注意力机制
NLP学习笔记(五) 注意力机制
114 0
|
机器学习/深度学习 自然语言处理
NLP学习笔记(四) Seq2Seq基本介绍
NLP学习笔记(四) Seq2Seq基本介绍
120 0