【文本分类】Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification

简介: 【文本分类】Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification

·摘要:

 从模型的角度,本文作者将RNN(Bi-LSTM)和attention mechanism结合使用,提出AttRNN模型,应用到了NLP的关系抽取(Relation Classification)中,也可应用到文本分类任务中,提高精度。

·参考文献:

 [1] Attention-Based Bidirectional Long Short-Term Memory Networks for

Relation Classification 论文链接:https://aclanthology.org/P16-2034.pdf

[1] 摘要


  · 重要的信息可能出现在句子的任何位置。为了解决这些问题,提出基于注意力机制的双向长短期记忆网络(AttBiLSTM)来捕获句子中最重要的语义信息

  简单的理解就是,给句子向量乘上一个权重向量,按权重向量重新计算向量值。

[2] 模型


image.png

  模型一共有6层,输入层、嵌入层、双向LSTM层、注意力机制层、全连接层、输出层。

   双向LSTM的输出为2倍(正反两个反向)的[h1, h2,…hT]。普通RNN模型,就会把此处双向LSTM的输出作为全连接层的输入进行分类,在本文中还需经过注意力层。

   注意力机制层的作用是找到一个句子中各个词的相关系数,然后把原来句子向量乘上这个系数。计算公式为:

image.png

 H HH是Bi-LSTM层的输出;H HH经过激活函数后变成M;w ww是一个可优化的一维张量数组相等,维度与H HH的最后一个维度,即Bi-LSTM层的hidden_size * 2; α \alphaα即为注意力权重系数,表示一个句子中的词语之间的相关性;r rr则为Bi-LSTM输出H HH经过加权求和后的结果;最后通过t a n h tanhtanh激活函数生成表征向量 h ∗ = t a n h ( r ) h^*=tanh(r)h ∗ =tanh(r);

 关于注意力机制,可以参考下面两篇文章:

  https://zhuanlan.zhihu.com/p/65304158

  https://zhuanlan.zhihu.com/p/393940472

[3] 代码复现


  贴出基础模型:

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.tanh1 = nn.Tanh()
        # self.u = nn.Parameter(torch.Tensor(config.hidden_size * 2, config.hidden_size * 2))
        self.w = nn.Parameter(torch.zeros(config.hidden_size * 2))
        self.tanh2 = nn.Tanh()
        self.fc1 = nn.Linear(config.hidden_size * 2, config.hidden_size2)
        self.fc = nn.Linear(config.hidden_size2, config.num_classes)
    def forward(self, x):
        x, _ = x
        emb = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        H, _ = self.lstm(emb)  # [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256]
        M = self.tanh1(H)  # [128, 32, 256]
        # M = torch.tanh(torch.matmul(H, self.u))
        alpha = F.softmax(torch.matmul(M, self.w), dim=1).unsqueeze(-1)  # [128, 32, 1]
        out = H * alpha  # [128, 32, 256]
        out = torch.sum(out, 1)  # [128, 256]
        out = F.relu(out)
        out = self.fc1(out)
        out = self.fc(out)  # [128, 64]
        return out

 实验结果(baseline):

数据集 RNN RCNN AttRNN
THUCNews 90.73% 91.21% 90.62%



相关文章
|
机器学习/深度学习 自然语言处理 TensorFlow
Long Short-Term Memory,简称 LSTM
长短期记忆(Long Short-Term Memory,简称 LSTM)是一种特殊的循环神经网络(RNN)结构,用于处理序列数据,如语音识别、自然语言处理、视频分析等任务。LSTM 网络的主要目的是解决传统 RNN 在训练过程中遇到的梯度消失和梯度爆炸问题,从而更好地捕捉序列数据中的长期依赖关系。
146 4
|
机器学习/深度学习 数据挖掘
【多标签文本分类】Balancing Methods for Multi-label Text Classification with Long-Tailed Class Distribution
【多标签文本分类】Balancing Methods for Multi-label Text Classification with Long-Tailed Class Distribution
154 0
【多标签文本分类】Balancing Methods for Multi-label Text Classification with Long-Tailed Class Distribution
JDK源码(11)-Long、Short
JDK源码(11)-Long、Short
132 0
java中整型数据(byte、short、int、long)溢出的现象及原理
java中整型数据(byte、short、int、long)溢出的现象及原理
|
机器学习/深度学习
3_Long Short Term Memory (LSTM)
3_Long Short Term Memory (LSTM)
150 0
3_Long Short Term Memory (LSTM)
|
机器学习/深度学习 算法框架/工具
(zhuan) Attention in Long Short-Term Memory Recurrent Neural Networks
Attention in Long Short-Term Memory Recurrent Neural Networks by Jason Brownlee on June 30, 2017 in Deep Learning   The Encoder-Decoder architecture i...
|
Linux C语言 存储
printf中的short int, int, long int和long long int
hd: short int d: int ld: long int lld: long long int Linux基本数据类型大小——int,char,long int,long long int 在Linux操作系统下使用GCC进行编程,目前一般处理器为32位字宽,下面是/usr/include/limit.h文件对Linux下数据类型的限制及存储字节大小的说明。
1197 0
|
SQL Oracle 关系型数据库
long sort 和 short sort
long sort 和 short sort转自 http://www.itpub.net/thread-1266906-1-1.html对这个帖子 http://www.itpub.net/thread-1266765-1-1.html的SQL做了点测试,顺便发现oracle 10g对排序这个操作还是很有点门道值得我们研究的。
1023 0
|
9月前
|
JSON JavaScript 前端开发
解决js中Long类型数据在请求与响应过程精度丢失问题(springboot项目中)
解决js中Long类型数据在请求与响应过程精度丢失问题(springboot项目中)
749 0