seq2seq与Attention机制(一)

简介: seq2seq与Attention机制(一)

学习目标



  • 目标
  • 掌握seq2seq模型特点
  • 掌握集束搜索方式
  • 掌握BLEU评估方法
  • 掌握Attention机制


  • 应用
  • 应用Keras实现seq2seq对日期格式的翻译


4.3.1 seq2seq



seq2seq模型是在2014年,是由Google Brain团队和Yoshua Bengio 两个团队各自独立的提出来。


4.3.1.1 定义


seq2seq是一个Encoder–Decoder 结构的网络,它的输入是一个序列,输出也是一个序列, Encoder 中将一个可变长度的信号序列变为固定长度的向量表达,Decoder 将这个固定长度的向量变成可变长度的目标的信号序列。


image.png


注:Cell可以用 RNN ,GRU,LSTM 等结构。


相当于将RNN模型当中的s^{0}s0输入变成一个encoder


4.3.1.2 条件语言模型理解


  • 1、编解码器作用


  • 编码器的作用是把一个不定长的输入序列x_{1},\ldots,x_{t}x1,…,xt,输出到一个编码状态CC
  • 解码器输出y^{t}yt的条件概率将基于之前的输出序列y^{1}, y^{t-1}y1,yt−1和编码状态CC


argmax {P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T)argmaxP(y1,…,yT′∣x1,…,xT),给定输入的序列,使得输出序列的概率值最大。


  • 2、根据最大似然估计,最大化输出序列的概率


{P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T) = \prod_{t'=1}^{T'} {P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, x_1, \ldots, x_T) = \prod_{t'=1}^{T'} {P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, {C})P(y1,…,yT′∣x1,…,xT)=∏t′=1T′P(yt′∣y1,…,yt′−1,x1,…,xT)=∏t′=1T′P(yt′∣y1,…,yt′−1,C)


由于这个公式需要求出:P(y^{1}|C) * P(y^{2}|y^{1},C)*P(y^{3}|y^{2},y^{2},y^{1},C)...P(y1∣C)∗P(y2∣y1,C)∗P(y3∣y2,y2,y1,C)...这个概率连乘会非常非常小不利于计算存储,所以需要对公式取对数计算:


\log{P}(y_1, \ldots, y_{T'} \mid x_1, \ldots, x_T) = \sum_{t'=1}^{T'} \log{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, {C})logP(y1,…,yT′∣x1,…,xT)=∑t′=1T′logP(yt′∣y1,…,yt′−1,C)


所以这样就变成了P(y^{1}|C)+ P(y^{2}|y^{1},C)+P(y^{3}|y^{2},y^{2},y^{1},C)...P(y1∣C)+P(y2∣y1,C)+P(y3∣y2,y2,y1,C)...概率相加。这样也可以看成输出结果通过softmax就变成了概率最大,而损失最小的问题,输出序列损失最小化。


4.3.1.3 应用场景


神经机器翻译(NMT)


image.png


聊天机器人

接下来我们来看注意力机制,那么普通的seq2seq会面临什么样的问题?


4.3.2 注意力机制



4.3.2.1 长句子问题


image.png


对于更长的句子,seq2seq就显得力不从心了,无法做到准确的翻译,一下是通常BLEU的分数随着句子的长度变化,可以看到句子非常长的时候,分数就很低。


image.png


本质原因:在Encoder-Decoder结构中,Encoder把所有的输入序列都编码成一个统一的语义特征CC再解码,因此, CC中必须包含原始序列中的所有信息,它的长度就成了限制模型性能的瓶颈。当要翻译的句子较长时,一个CC可能存不下那么多信息,就会造成翻译精度的下降。


4.3.2.2 定义


  • 建立Encoder的隐层状态输出到Decoder对应输出y所需要的上下文信息
  • 目的:增加编码器信息输入到解码器中相同时刻的联系,其它时刻信息减弱


image.png


4.3.2.3 公式


注意上述的几个细节,颜色的连接深浅不一样,假设Encoder的时刻记为tt,而Decoder的时刻记为t^{'}t′。


1、{c}_{t'} = \sum_{t=1}^T \alpha_{t' t}{h}_tct′=∑t=1Tαt′tht


  • \alpha_{t^{'}t}αt′t为参数,在网络中训练得到


  • 理解:蓝色的解码器中的cell举例子


  • \alpha_{4}{1}h_{1}+\alpha_{4}{2}h_{2} + \alpha_{4}{3}h_{3} + \alpha_{4}{4}h_{4} = c_{4}α41h1+α42h2+α43h3+α44h4=c4


2、\alpha_{t^{'}t}αt′t的N个权重系数由来?


  • 权重系数通过softmax计算:\alpha_{t' t} = \frac{\exp(e_{t' t})}{ \sum_{k=1}^T \exp(e_{t' k}) },\quad t=1,\ldots,Tαt′t=∑k=1Texp(et′k)exp(et′t),t=1,…,T


  • e_{t' t} = g({s}_{t' - 1}, {h}_t)= {v}^\top \tanh({W}_s {s} + {W}_h {h})et′t=g(st′−1,ht)=v⊤tanh(Wss+Whh)


  • e_{t' t}et′t是由t时刻的编码器隐层状态输出和解码器t^{'}-1t′−1时刻的隐层状态输出计算出来的
  • ss为解码器隐层状态输出,hh为编码器隐层状态输出
  • v,W_{s},W_{h}v,Ws,Wh都是网络学习的参数


image.png


4.3.3 机器翻译案例



4.3.3.1 问题


使用简单的“日期转换”任务代替翻译任务,为了不然训练时间变得太长。


网络将输入以各种可能格式(例如“1958年8月29日”,“03/30/1968”,“1987年6月24日”)编写的日期,并将其翻译成标准化的机器可读日期(例如“1958 -08-29“,”1968-03-30“,”1987-06-24“)。使用seq2seq网络学习以通用机器可读格式YYYY-MM-DD输出日期。


4.3.3.2 相关环境与结果演示


pip install faker
pip install tqdm
pip install babel
pip install keras==2.2.4


  • faker:生成数据包
  • tqdm:python扩展包
  • babel:代码装换器
  • keras:更加方便简洁的深度学习库
  • 为了快速编写代码


4.3.3.4 代码分析


  • Seq2seq():


  • 序列模型类
  • load_data(self,m):加载数据类,选择加载多少条数据
  • init_seq2seq(self):初始化模型,需要自定义自己的模型
  • self.get_encoder(self):定义编码器
  • self.get_decoder(self):定义解码器
  • self.get_attention(self):定义注意力机制
  • self.get_output_layer(self):定义解码器输出层
  • model(self):定义模型整体输入输出逻辑
  • train(self, X_onehot, Y_onehot):训练模型
  • test(self):测试模型


  • 训练


if __name__ == '__main__':
    s2s = Seq2seq()
    X_onehot, Y_onehot = s2s.load_data(10000)
    s2s.init_seq2seq()
    s2s.train(X_onehot, Y_onehot)


整个数据集特征值的形状: (10000, 30, 37)
整个数据集目标值的形状: (10000, 10, 11)
查看第一条数据集格式:特征值:9 may 1998, 目标值: 1998-05-09
[12  0 24 13 34  0  4 12 12 11 36 36 36 36 36 36 36 36 36 36 36 36 36 36
 36 36 36 36 36 36] [ 2 10 10  9  0  1  6  0  1 10]
one_hot编码: [[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 1.]] [[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
Epoch 1/1
  100/10000 [..............................] - ETA: 10:52 - loss: 23.9884 - dense_1_loss: 2.3992 - dense_1_acc: 0.3200 - dense_1_acc_1: 0.0000e+00 - dense_1_acc_2: 0.0100 - dense_1_acc_3: 0.1300 - dense_1_acc_4: 0.0000e+00 - dense_1_acc_5: 0.0400 - dense_1_acc_6: 0.0900 - dense_1_acc_7: 0.0000e+00 - dense_1_acc_8: 0.3500 - dense_1_acc_9: 0.1100
  200/10000 [..............................] - ETA: 5:27 - loss: 23.9289 - dense_1_loss: 2.3991 - dense_1_acc: 0.2550 - dense_1_acc_1: 0.0000e+00 - dense_1_acc_2: 0.0050 - dense_1_acc_3: 0.1150 - dense_1_acc_4: 0.0950 - dense_1_acc_5: 0.0250 - dense_1_acc_6: 0.1150 - dense_1_acc_7: 0.0800 - dense_1_acc_8: 0.3400 - dense_1_acc_9: 0.1050


测试


if __name__ == '__main__':
    s2s = Seq2seq()
    X_onehot, Y_onehot = s2s.load_data(10000)
    s2s.init_seq2seq()
    s2s.train(X_onehot, Y_onehot)
    # s2s.test()


source: 1 March 2001
output: 2001-03-01
目录
相关文章
|
6月前
|
机器学习/深度学习 缓存 人工智能
45_混合专家模型:MoE架构详解
在大语言模型的发展历程中,参数规模的扩张一直被视为提升性能的主要途径。然而,随着模型参数达到数百亿甚至数千亿级别,传统的密集型模型架构面临着计算资源、训练效率和推理速度等诸多挑战。2025年,混合专家模型(Mixture of Experts,MoE)已成为突破这些限制的关键技术路径。
1210 0
|
监控 网络协议 安全
华为配置防火墙直连路由器出口实验
华为配置防火墙直连路由器出口实验
989 6
|
机器学习/深度学习 自然语言处理 算法
Transformer 模型:入门详解(1)
动动发财的小手,点个赞吧!
14329 1
Transformer 模型:入门详解(1)
|
Shell Android开发
Android10.0(Q) 默认应用设置(电话、短信、浏览器、主屏幕应用)
Android10.0(Q) 默认应用设置(电话、短信、浏览器、主屏幕应用)
1592 0
|
3月前
|
机器学习/深度学习 人工智能 监控
从原理到实践:零代码也能搞定的PPO微调全攻略
本文深入浅出解析PPO(近端策略优化)算法——大模型对齐人类偏好的核心技术。通过“温和教练”比喻、四步原理拆解与实操指南,零基础也能理解其剪切机制、优势函数与稳定训练逻辑,并亲手微调出更懂你的AI。(239字)
305 0
|
自然语言处理 IDE 测试技术
通义灵码——有了它让我的编程效率和质量直线上升!
作为一名大数据开发工程师,我每天与代码和数据打交道,享受解决复杂问题的乐趣。最近,我遇到了一位超级“码”力助手——通义灵码。它不仅是一个简单的代码补全工具,更像是一个拥有高度智慧的编程伙伴,能够理解我的编程意图,给出最合适的建议,大大提升了我的工作效率和编程体验。本文将分享如何在VsCode中安装和使用通义灵码,以及它在我的实际编程工作中发挥的重要作用。
|
JavaScript 索引 前端开发
9.【TypeScript 教程】接口(Interface)
9.【TypeScript 教程】接口(Interface)
341 4
|
消息中间件 Java Spring
Spring Boot与NATS消息系统的集成方法
Spring Boot与NATS消息系统的集成方法
|
机器学习/深度学习 存储 自然语言处理
NLP中的RNN、Seq2Seq与attention注意力机制(下)
NLP中的RNN、Seq2Seq与attention注意力机制(下)
|
监控 测试技术 Shell
APP的CPU,内存和流量如何测试?
APP的CPU,内存和流量如何测试?
753 0

热门文章

最新文章

下一篇
开通oss服务