NLP学习笔记(三) GRU基本介绍

简介: NLP学习笔记(三) GRU基本介绍

前言


大家好,我是半虹,这篇文章来讲门控循环单元 (Gated Recurrent Unit, GRU)

文章行文思路如下:

  1. 首先通过长短期记忆网络引出为什么需要门控循环单元
  2. 然后介绍门控循环单元的核心思想与运作方式
  3. 最后通过简洁的代码深入理解门控循环单元的运作方式


正文


在之前的文章中,我们已经介绍过循环神经网络和长短期记忆网络


知道了长短期记忆网络是为了缓解循环神经网络容易出现梯度消失的情况而设计的


然而,长短期记忆网络的参数确实有点多,计算速度也是有点慢,所以后来就有人提出了门控循环单元


门控循环单元与长短期记忆网络效果相当,但是其参数更少,且计算速度更快



对比长短期记忆网络,门控循环单元去除了记忆元,但仍保留了门机制,只不过门机制的种类稍有不同


以下是循环神经网络、长短期记忆网络、门控循环单元三者的对比

网络 是否有记忆元 传递状态 是否有门机制 门机制的种类
循环神经网络 隐状态
长短期记忆网络 隐状态、记忆元 输入门、遗忘门、输出门
门控循环单元 隐状态 重置门、更新门


我们发现,门控循环单元仅在隐状态上就能实现对长期记忆的控制


这是怎么做到的呢?其核心就在于门机制,通过门机制控制隐状态中的信息流动


从直觉上来说,先前重要的记忆会保留在隐状态,不重要的记忆会被过滤,以此达到长期记忆的目的



门控循环单元中的门机制包括两类:


重置门:用于控制记住多少旧状态,英文为 Reset Gate \text{Reset Gate}Reset Gate

更新门:用于控制新旧状态的占比,英文为 Update Gate \text{Update Gate}Update Gate


实际上,所谓的门机制,就是一个带激活函数的线性层而已,且激活函数通常会用 sigmoid \text{sigmoid}sigmoid


因为这样能将输出限制在零到一之间,以表示门的打开程度,控制信息流动的程度


0cd6cafc96e04d22ca365a4e612c89f.jpg


最后我们来简单实现一下门控循环单元

作为例子,我们用这个门控循环单元对以下句子进行编码:我在画画

import torch
import torch.nn as nn
# 定义输入数据
# 对于输入句子我在画画,首先用独热编码得到其向量表示
x1 = torch.tensor([1, 0, 0]).float() # 我
x2 = torch.tensor([0, 1, 0]).float() # 在
x3 = torch.tensor([0, 0, 1]).float() # 画
x4 = torch.tensor([0, 0, 1]).float() # 画
h0 = torch.zeros(5) # 初始化隐状态
# 定义模型参数
# 模型的输入是三维向量,这里定义模型的输出是五维向量
W_xr = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hr = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_r  = nn.Parameter(torch.randn(5)   , requires_grad = True)
W_xz = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hz = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_z  = nn.Parameter(torch.randn(5)   , requires_grad = True)
W_xh = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hh = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_h  = nn.Parameter(torch.randn(5)   , requires_grad = True)
# 前向传播
def forward(X, H):
    # 计算各种门机制
    R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r) # 重置门
    Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z) # 更新门
    # 计算候选隐状态
    H_tilde = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(R * H, W_hh) + b_h)
    # 计算当前隐状态
    H = Z * H + (1 - Z) * H_tilde
    # 返回结果
    return H
h1 = forward(x1, h0)
h2 = forward(x2, h1)
h3 = forward(x3, h2)
h4 = forward(x4, h3)
# 结果输出
print(h3) # tensor([ 0.7936, -0.9788,  0.8360,  0.2307, -0.9928])
print(h4) # tensor([ 0.8460, -0.9946,  0.9130,  0.0313, -0.9986])


至此本文结束,要点总结如下:

  1. 门控循环单元与长短期记忆网络效果相当,但是其参数更少,且计算速度更快
  2. 门控循环单元通过门机制,仅在隐状态上就能实现对长期记忆的控制



目录
相关文章
|
自然语言处理 算法
NLP学习笔记(十) 分词(下)
NLP学习笔记(十) 分词(下)
103 0
|
机器学习/深度学习 自然语言处理
NLP学习笔记(八) GPT简明介绍 下
NLP学习笔记(八) GPT简明介绍
131 0
|
自然语言处理
NLP学习笔记(八) GPT简明介绍 上
NLP学习笔记(八) GPT简明介绍
112 0
|
自然语言处理
NLP学习笔记(七) BERT简明介绍 下
NLP学习笔记(七) BERT简明介绍
151 0
NLP学习笔记(七) BERT简明介绍 下
|
机器学习/深度学习 自然语言处理
NLP学习笔记(七) BERT简明介绍 上
NLP学习笔记(七) BERT简明介绍
93 0
|
机器学习/深度学习 自然语言处理 计算机视觉
NLP学习笔记(六) Transformer简明介绍
NLP学习笔记(六) Transformer简明介绍
142 0
|
机器学习/深度学习 自然语言处理
NLP学习笔记(五) 注意力机制
NLP学习笔记(五) 注意力机制
114 0
|
机器学习/深度学习 自然语言处理
NLP学习笔记(四) Seq2Seq基本介绍
NLP学习笔记(四) Seq2Seq基本介绍
120 0
|
2月前
|
机器学习/深度学习 自然语言处理
利用深度学习技术改进自然语言处理中的命名实体识别
命名实体识别(Named Entity Recognition, NER)在自然语言处理领域扮演着重要角色,但传统方法在处理复杂语境和多样化实体时存在局限性。本文将探讨如何利用深度学习技术,特别是基于预训练模型的方法,来改进命名实体识别,提高其在现实场景中的性能和适用性。
|
2月前
|
机器学习/深度学习 自然语言处理 监控
利用深度学习技术实现自然语言处理中的情感分析
本文将深入探讨如何利用深度学习技术在自然语言处理领域中实现情感分析。通过介绍情感分析的背景和原理,结合深度学习模型如LSTM、BERT等的应用,帮助读者了解情感分析的重要性以及如何利用最新技术实现更准确的情感识别。