seq2seq与Attention机制(一)

简介: seq2seq与Attention机制(一)

学习目标



  • 目标
  • 掌握seq2seq模型特点
  • 掌握集束搜索方式
  • 掌握BLEU评估方法
  • 掌握Attention机制


  • 应用
  • 应用Keras实现seq2seq对日期格式的翻译


4.3.1 seq2seq



seq2seq模型是在2014年,是由Google Brain团队和Yoshua Bengio 两个团队各自独立的提出来。


4.3.1.1 定义


seq2seq是一个Encoder–Decoder 结构的网络,它的输入是一个序列,输出也是一个序列, Encoder 中将一个可变长度的信号序列变为固定长度的向量表达,Decoder 将这个固定长度的向量变成可变长度的目标的信号序列。


image.png


注:Cell可以用 RNN ,GRU,LSTM 等结构。


相当于将RNN模型当中的s^{0}s0输入变成一个encoder


4.3.1.2 条件语言模型理解


  • 1、编解码器作用


  • 编码器的作用是把一个不定长的输入序列x_{1},\ldots,x_{t}x1,…,xt,输出到一个编码状态CC
  • 解码器输出y^{t}yt的条件概率将基于之前的输出序列y^{1}, y^{t-1}y1,yt−1和编码状态CC


argmax {P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T)argmaxP(y1,…,yT′∣x1,…,xT),给定输入的序列,使得输出序列的概率值最大。


  • 2、根据最大似然估计,最大化输出序列的概率


{P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T) = \prod_{t'=1}^{T'} {P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, x_1, \ldots, x_T) = \prod_{t'=1}^{T'} {P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, {C})P(y1,…,yT′∣x1,…,xT)=∏t′=1T′P(yt′∣y1,…,yt′−1,x1,…,xT)=∏t′=1T′P(yt′∣y1,…,yt′−1,C)


由于这个公式需要求出:P(y^{1}|C) * P(y^{2}|y^{1},C)*P(y^{3}|y^{2},y^{2},y^{1},C)...P(y1∣C)∗P(y2∣y1,C)∗P(y3∣y2,y2,y1,C)...这个概率连乘会非常非常小不利于计算存储,所以需要对公式取对数计算:


\log{P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T) = \sum_{t'=1}^{T'} \log{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, {C})logP(y1,…,yT′∣x1,…,xT)=∑t′=1T′logP(yt′∣y1,…,yt′−1,C)


所以这样就变成了P(y^{1}|C)+ P(y^{2}|y^{1},C)+P(y^{3}|y^{2},y^{2},y^{1},C)...P(y1∣C)+P(y2∣y1,C)+P(y3∣y2,y2,y1,C)...概率相加。这样也可以看成输出结果通过softmax就变成了概率最大,而损失最小的问题,输出序列损失最小化。


4.3.1.3 应用场景


神经机器翻译(NMT)


image.png


聊天机器人

接下来我们来看注意力机制,那么普通的seq2seq会面临什么样的问题?


4.3.2 注意力机制



4.3.2.1 长句子问题


image.png


对于更长的句子,seq2seq就显得力不从心了,无法做到准确的翻译,一下是通常BLEU的分数随着句子的长度变化,可以看到句子非常长的时候,分数就很低。


image.png


本质原因:在Encoder-Decoder结构中,Encoder把所有的输入序列都编码成一个统一的语义特征CC再解码,因此, CC中必须包含原始序列中的所有信息,它的长度就成了限制模型性能的瓶颈。当要翻译的句子较长时,一个CC可能存不下那么多信息,就会造成翻译精度的下降。


4.3.2.2 定义


  • 建立Encoder的隐层状态输出到Decoder对应输出y所需要的上下文信息
  • 目的:增加编码器信息输入到解码器中相同时刻的联系,其它时刻信息减弱


image.png


4.3.2.3 公式


注意上述的几个细节,颜色的连接深浅不一样,假设Encoder的时刻记为tt,而Decoder的时刻记为t^{'}t′。


1、{c}_{t'} = \sum_{t=1}^T \alpha_{t' t}{h}_tct′=∑t=1Tαt′tht


  • \alpha_{t^{'}t}αt′t为参数,在网络中训练得到


  • 理解:蓝色的解码器中的cell举例子


  • \alpha_{4}{1}h_{1}+\alpha_{4}{2}h_{2} + \alpha_{4}{3}h_{3} + \alpha_{4}{4}h_{4} = c_{4}α41h1+α42h2+α43h3+α44h4=c4


2、\alpha_{t^{'}t}αt′t的N个权重系数由来?


  • 权重系数通过softmax计算:\alpha_{t' t} = \frac{\exp(e_{t' t})}{ \sum_{k=1}^T \exp(e_{t' k}) },\quad t=1,\ldots,Tαt′t=∑k=1Texp(et′k)exp(et′t),t=1,…,T


  • e_{t' t} = g({s}_{t' - 1}, {h}_t)= {v}^\top \tanh({W}_s {s} + {W}_h {h})et′t=g(st′−1,ht)=v⊤tanh(Wss+Whh)


  • e_{t' t}et′t是由t时刻的编码器隐层状态输出和解码器t^{'}-1t′−1时刻的隐层状态输出计算出来的
  • ss为解码器隐层状态输出,hh为编码器隐层状态输出
  • v,W_{s},W_{h}v,Ws,Wh都是网络学习的参数


image.png


4.3.3 机器翻译案例



4.3.3.1 问题


使用简单的“日期转换”任务代替翻译任务,为了不然训练时间变得太长。


网络将输入以各种可能格式(例如“1958年8月29日”,“03/30/1968”,“1987年6月24日”)编写的日期,并将其翻译成标准化的机器可读日期(例如“1958 -08-29“,”1968-03-30“,”1987-06-24“)。使用seq2seq网络学习以通用机器可读格式YYYY-MM-DD输出日期。


4.3.3.2 相关环境与结果演示


pip install faker
pip install tqdm
pip install babel
pip install keras==2.2.4


  • faker:生成数据包
  • tqdm:python扩展包
  • babel:代码装换器
  • keras:更加方便简洁的深度学习库
  • 为了快速编写代码


4.3.3.4 代码分析


  • Seq2seq():


  • 序列模型类
  • load_data(self,m):加载数据类,选择加载多少条数据
  • init_seq2seq(self):初始化模型,需要自定义自己的模型
  • self.get_encoder(self):定义编码器
  • self.get_decoder(self):定义解码器
  • self.get_attention(self):定义注意力机制
  • self.get_output_layer(self):定义解码器输出层
  • model(self):定义模型整体输入输出逻辑
  • train(self, X_onehot, Y_onehot):训练模型
  • test(self):测试模型


  • 训练


if __name__ == '__main__':
    s2s = Seq2seq()
    X_onehot, Y_onehot = s2s.load_data(10000)
    s2s.init_seq2seq()
    s2s.train(X_onehot, Y_onehot)


整个数据集特征值的形状: (10000, 30, 37)
整个数据集目标值的形状: (10000, 10, 11)
查看第一条数据集格式:特征值:9 may 1998, 目标值: 1998-05-09
[12  0 24 13 34  0  4 12 12 11 36 36 36 36 36 36 36 36 36 36 36 36 36 36
 36 36 36 36 36 36] [ 2 10 10  9  0  1  6  0  1 10]
one_hot编码: [[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 1.]] [[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
Epoch 1/1
  100/10000 [..............................] - ETA: 10:52 - loss: 23.9884 - dense_1_loss: 2.3992 - dense_1_acc: 0.3200 - dense_1_acc_1: 0.0000e+00 - dense_1_acc_2: 0.0100 - dense_1_acc_3: 0.1300 - dense_1_acc_4: 0.0000e+00 - dense_1_acc_5: 0.0400 - dense_1_acc_6: 0.0900 - dense_1_acc_7: 0.0000e+00 - dense_1_acc_8: 0.3500 - dense_1_acc_9: 0.1100
  200/10000 [..............................] - ETA: 5:27 - loss: 23.9289 - dense_1_loss: 2.3991 - dense_1_acc: 0.2550 - dense_1_acc_1: 0.0000e+00 - dense_1_acc_2: 0.0050 - dense_1_acc_3: 0.1150 - dense_1_acc_4: 0.0950 - dense_1_acc_5: 0.0250 - dense_1_acc_6: 0.1150 - dense_1_acc_7: 0.0800 - dense_1_acc_8: 0.3400 - dense_1_acc_9: 0.1050


测试


if __name__ == '__main__':
    s2s = Seq2seq()
    X_onehot, Y_onehot = s2s.load_data(10000)
    s2s.init_seq2seq()
    s2s.train(X_onehot, Y_onehot)
    # s2s.test()


source: 1 March 2001
output: 2001-03-01
目录
相关文章
|
机器学习/深度学习 自然语言处理 语音技术
从 Seq2Seq 到 Attention:彻底改变序列建模
从 Seq2Seq 到 Attention:彻底改变序列建模
73 0
|
5月前
|
机器学习/深度学习 自然语言处理
序列到序列(Seq2Seq)模型
序列到序列(Seq2Seq)模型
227 8
|
5月前
|
机器学习/深度学习 自然语言处理
seq2seq的机制原理
【8月更文挑战第1天】seq2seq的机制原理。
38 1
|
8月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
seq2seq:中英文翻译
seq2seq:中英文翻译
61 1
|
8月前
|
机器学习/深度学习 存储 自然语言处理
NLP中的RNN、Seq2Seq与attention注意力机制(下)
NLP中的RNN、Seq2Seq与attention注意力机制(下)
69 1
|
8月前
|
机器学习/深度学习 存储 自然语言处理
NLP中的RNN、Seq2Seq与attention注意力机制(上)
NLP中的RNN、Seq2Seq与attention注意力机制
72 1
|
8月前
|
机器学习/深度学习 人工智能 自然语言处理
详细介绍Seq2Seq、Attention、Transformer !!
详细介绍Seq2Seq、Attention、Transformer !!
173 0
|
机器学习/深度学习 存储 自然语言处理
深入解析序列模型:全面阐释 RNN、LSTM 与 Seq2Seq 的秘密
深入解析序列模型:全面阐释 RNN、LSTM 与 Seq2Seq 的秘密
172 0
|
机器学习/深度学习 自然语言处理 文字识别
初步了解RNN, Seq2Seq, Attention注意力机制
初步了解RNN, Seq2Seq, Attention注意力机制
137 0
初步了解RNN, Seq2Seq, Attention注意力机制
|
机器学习/深度学习 自然语言处理 PyTorch
【文本摘要(3)】Pytorch之Seq2seq: attention
【文本摘要(3)】Pytorch之Seq2seq: attention
104 0