学习目标
- 目标
- 掌握seq2seq模型特点
- 掌握集束搜索方式
- 掌握BLEU评估方法
- 掌握Attention机制
- 应用
- 应用Keras实现seq2seq对日期格式的翻译
4.3.1 seq2seq
seq2seq模型是在2014年,是由Google Brain团队和Yoshua Bengio 两个团队各自独立的提出来。
4.3.1.1 定义
seq2seq是一个Encoder–Decoder 结构的网络,它的输入是一个序列,输出也是一个序列, Encoder 中将一个可变长度的信号序列变为固定长度的向量表达,Decoder 将这个固定长度的向量变成可变长度的目标的信号序列。
注:Cell可以用 RNN ,GRU,LSTM 等结构。
相当于将RNN模型当中的s^{0}s0输入变成一个encoder
4.3.1.2 条件语言模型理解
- 1、编解码器作用
- 编码器的作用是把一个不定长的输入序列x_{1},\ldots,x_{t}x1,…,xt,输出到一个编码状态CC
- 解码器输出y^{t}yt的条件概率将基于之前的输出序列y^{1}, y^{t-1}y1,yt−1和编码状态CC
argmax {P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T)argmaxP(y1,…,yT′∣x1,…,xT),给定输入的序列,使得输出序列的概率值最大。
- 2、根据最大似然估计,最大化输出序列的概率
{P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T) = \prod_{t'=1}^{T'} {P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, x_1, \ldots, x_T) = \prod_{t'=1}^{T'} {P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, {C})P(y1,…,yT′∣x1,…,xT)=∏t′=1T′P(yt′∣y1,…,yt′−1,x1,…,xT)=∏t′=1T′P(yt′∣y1,…,yt′−1,C)
由于这个公式需要求出:P(y^{1}|C) * P(y^{2}|y^{1},C)*P(y^{3}|y^{2},y^{2},y^{1},C)...P(y1∣C)∗P(y2∣y1,C)∗P(y3∣y2,y2,y1,C)...这个概率连乘会非常非常小不利于计算存储,所以需要对公式取对数计算:
\log{P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T) = \sum_{t'=1}^{T'} \log{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, {C})logP(y1,…,yT′∣x1,…,xT)=∑t′=1T′logP(yt′∣y1,…,yt′−1,C)
所以这样就变成了P(y^{1}|C)+ P(y^{2}|y^{1},C)+P(y^{3}|y^{2},y^{2},y^{1},C)...P(y1∣C)+P(y2∣y1,C)+P(y3∣y2,y2,y1,C)...概率相加。这样也可以看成输出结果通过softmax就变成了概率最大,而损失最小的问题,输出序列损失最小化。
4.3.1.3 应用场景
神经机器翻译(NMT)
聊天机器人
接下来我们来看注意力机制,那么普通的seq2seq会面临什么样的问题?
4.3.2 注意力机制
4.3.2.1 长句子问题
对于更长的句子,seq2seq就显得力不从心了,无法做到准确的翻译,一下是通常BLEU的分数随着句子的长度变化,可以看到句子非常长的时候,分数就很低。
本质原因:在Encoder-Decoder结构中,Encoder把所有的输入序列都编码成一个统一的语义特征CC再解码,因此, CC中必须包含原始序列中的所有信息,它的长度就成了限制模型性能的瓶颈。当要翻译的句子较长时,一个CC可能存不下那么多信息,就会造成翻译精度的下降。
4.3.2.2 定义
- 建立Encoder的隐层状态输出到Decoder对应输出y所需要的上下文信息
- 目的:增加编码器信息输入到解码器中相同时刻的联系,其它时刻信息减弱
4.3.2.3 公式
注意上述的几个细节,颜色的连接深浅不一样,假设Encoder的时刻记为tt,而Decoder的时刻记为t^{'}t′。
1、{c}_{t'} = \sum_{t=1}^T \alpha_{t' t}{h}_tct′=∑t=1Tαt′tht
- \alpha_{t^{'}t}αt′t为参数,在网络中训练得到
- 理解:蓝色的解码器中的cell举例子
- \alpha_{4}{1}h_{1}+\alpha_{4}{2}h_{2} + \alpha_{4}{3}h_{3} + \alpha_{4}{4}h_{4} = c_{4}α41h1+α42h2+α43h3+α44h4=c4
2、\alpha_{t^{'}t}αt′t的N个权重系数由来?
- 权重系数通过softmax计算:\alpha_{t' t} = \frac{\exp(e_{t' t})}{ \sum_{k=1}^T \exp(e_{t' k}) },\quad t=1,\ldots,Tαt′t=∑k=1Texp(et′k)exp(et′t),t=1,…,T
- e_{t' t} = g({s}_{t' - 1}, {h}_t)= {v}^\top \tanh({W}_s {s} + {W}_h {h})et′t=g(st′−1,ht)=v⊤tanh(Wss+Whh)
- e_{t' t}et′t是由t时刻的编码器隐层状态输出和解码器t^{'}-1t′−1时刻的隐层状态输出计算出来的
- ss为解码器隐层状态输出,hh为编码器隐层状态输出
- v,W_{s},W_{h}v,Ws,Wh都是网络学习的参数
4.3.3 机器翻译案例
4.3.3.1 问题
使用简单的“日期转换”任务代替翻译任务,为了不然训练时间变得太长。
网络将输入以各种可能格式(例如“1958年8月29日”,“03/30/1968”,“1987年6月24日”)编写的日期,并将其翻译成标准化的机器可读日期(例如“1958 -08-29“,”1968-03-30“,”1987-06-24“)。使用seq2seq网络学习以通用机器可读格式YYYY-MM-DD输出日期。
4.3.3.2 相关环境与结果演示
pip install faker pip install tqdm pip install babel pip install keras==2.2.4
- faker:生成数据包
- tqdm:python扩展包
- babel:代码装换器
- keras:更加方便简洁的深度学习库
- 为了快速编写代码
4.3.3.4 代码分析
- Seq2seq():
- 序列模型类
- load_data(self,m):加载数据类,选择加载多少条数据
- init_seq2seq(self):初始化模型,需要自定义自己的模型
- self.get_encoder(self):定义编码器
- self.get_decoder(self):定义解码器
- self.get_attention(self):定义注意力机制
- self.get_output_layer(self):定义解码器输出层
- model(self):定义模型整体输入输出逻辑
- train(self, X_onehot, Y_onehot):训练模型
- test(self):测试模型
- 训练
if __name__ == '__main__': s2s = Seq2seq() X_onehot, Y_onehot = s2s.load_data(10000) s2s.init_seq2seq() s2s.train(X_onehot, Y_onehot)
整个数据集特征值的形状: (10000, 30, 37) 整个数据集目标值的形状: (10000, 10, 11) 查看第一条数据集格式:特征值:9 may 1998, 目标值: 1998-05-09 [12 0 24 13 34 0 4 12 12 11 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36] [ 2 10 10 9 0 1 6 0 1 10] one_hot编码: [[0. 0. 0. ... 0. 0. 0.] [1. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 1.] [0. 0. 0. ... 0. 0. 1.] [0. 0. 0. ... 0. 0. 1.]] [[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]] Epoch 1/1 100/10000 [..............................] - ETA: 10:52 - loss: 23.9884 - dense_1_loss: 2.3992 - dense_1_acc: 0.3200 - dense_1_acc_1: 0.0000e+00 - dense_1_acc_2: 0.0100 - dense_1_acc_3: 0.1300 - dense_1_acc_4: 0.0000e+00 - dense_1_acc_5: 0.0400 - dense_1_acc_6: 0.0900 - dense_1_acc_7: 0.0000e+00 - dense_1_acc_8: 0.3500 - dense_1_acc_9: 0.1100 200/10000 [..............................] - ETA: 5:27 - loss: 23.9289 - dense_1_loss: 2.3991 - dense_1_acc: 0.2550 - dense_1_acc_1: 0.0000e+00 - dense_1_acc_2: 0.0050 - dense_1_acc_3: 0.1150 - dense_1_acc_4: 0.0950 - dense_1_acc_5: 0.0250 - dense_1_acc_6: 0.1150 - dense_1_acc_7: 0.0800 - dense_1_acc_8: 0.3400 - dense_1_acc_9: 0.1050
测试
if __name__ == '__main__': s2s = Seq2seq() X_onehot, Y_onehot = s2s.load_data(10000) s2s.init_seq2seq() s2s.train(X_onehot, Y_onehot) # s2s.test()
source: 1 March 2001 output: 2001-03-01