Seq2Seq、SeqGAN、Transformer…你都掌握了吗?一文总结文本生成必备经典模型(2)

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
简介: Seq2Seq、SeqGAN、Transformer…你都掌握了吗?一文总结文本生成必备经典模型

SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient


这篇文章用对抗网络实现了离散序列数据的生成模型,解决了对抗生成网络难应用于nlp领域的问题,并且在文本生成任务上有优异表现。

图5. SeqGAN图示。左图:D通过真实数据和G生成的数据进行训练。G通过策略梯度进行训练,最终的奖励信号由D提供,并通过蒙特卡洛搜索传递回中间的行动值

序列生成问题表示如下。给定一个真实世界结构化序列的数据集,训练一个θ参数化的生成模型Gθ来生成一个序列Y_1:T=(y_1, ..., y_t, ..., y_T),y_t∈Y,其中Y是候选标记的词汇表。作者根据强化学习来解释这个问题。在时间步骤t中,状态s是当前产生的token(y_1, ..., y_t-1),行动a是要选择的下一个token y_t。因此,策略模型Gθ(y_t|Y_1:t-1)是随机的,而在选择了一个行动后,状态转换是确定的,即如果当前状态s=Y_1:t-1,行动a=y_t,则下一个状态s’=Y_1:t的(δ_s,s’)^a=1;对于其他下一个状态s'',(δ_s,s'')^a=0。此外,还训练了一个φ参数化的鉴别模型Dφ,为改进生成器Gθ提供指导。Dφ(Y_1:T )是一个概率,表示一个序列Y_1:T来自真实序列数据的可能性。通过提供来自真实序列数据的正样本和来自生成模型Gθ生成的合成序列的负样本来训练鉴别模型Dφ。同时,生成模型Gθ通过采用策略梯度和MC搜索,根据从鉴别模型Dφ得到的预期最终奖励进行更新。奖励是通过骗过鉴别模型Dφ的可能性来估计的。生成器模型(策略)Gθ(y_t|Y_1:t-1)的目标是从起始状态s'生成一个序列,使其预期的最终奖励最大化:


(Q_Dφ(s, a))^Gθ是一个序列的行动价值函数,即从状态s开始,采取行动a,然后遵循策略Gθ的预期累积奖励。序列的目标函数的合理性是:从一个给定的初始状态开始,生成器的目标是生成一个能使鉴别器认为它是真实的序列。使用REINFORCE算法,将鉴别器Dφ((Y_1:T)^n)估计的真实概率作为奖励:


由上式,state指的当前timestep之前的decode结果,action指的当前待解码词,D网络鉴别伪造数据的置信度即为奖励,伪造数据越逼真则相应奖励越大,但该奖励是总的奖励,分配到每个词选择上的reward则采用了以下的近似方法:


即当解码到t时,对后面T-t个timestep采用蒙特卡洛搜索搜索出N条路径,将这N条路径分别和已经decode的结果组成N条完整输出,然后将D网络对应奖励的平均值作为reward。当t=T时无法再向后探索路径,所以直接以完整decode结果的奖励作为reward。蒙特卡洛搜索是指在选择下一个节点的时候用蒙特卡洛采样的方式,而蒙特卡洛采样是指根据当前输出词表的置信度随机采样。完整算法流程如下:

随机初始化G网络和D网络参数;通过MLE预训练G网络,目的是提高G网络的搜索效率;通过G网络生成部分负样预训练D网络;通过G网络生成sequence用D网络去评判,得到reward:


根据上式(4)计算得到每个action选择得到的奖励并求得累积奖励的期望,以此为loss function,并求导对网络进行梯度更新。其中,下式是标准的D网络误差函数,训练目标是最大化识别真实样本的概率,最小化误识别伪造样本的概率:


最后,GAN网络的误差函数如上,循环以上过程直至收敛。

当前 SOTA!平台收录 SeqGAN 共 22 个模型实现资源,支持的主流框架包含 PyTorch、TensorFlow 等。

项目 SOTA!平台项目详情页
SeqGAN 前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seqgan

Attention is all you need


2017 年,Google 机器翻译团队发表的《Attention is All You Need》完全抛弃了RNN和CNN等网络结构,而仅仅采用Attention机制来完成机器翻译任务,并且取得了很好的效果,注意力机制也成为了研究热点。大多数竞争性神经序列转导模型都有一个编码器-解码器结构。编码器将输入的符号表示序列(x1, ..., xn)映射到连续表示的序列z=(z1, ..., zn)。给定z后,解码器每次生成一个元素的符号输出序列(y1, ..., ym)。在每个步骤中,该模型是自动回归的,在生成下一个符号时,将先前生成的符号作为额外的输入。Transformer遵循这一整体架构,在编码器和解码器中都使用了堆叠式自注意力和点式全连接层,分别在图6的左半部和右半部显示。

图6. Transformer架构


编码器。编码器是由N=6个相同的层堆叠而成。每层有两个子层。第一层是一个多头自注意力机制,第二层是一个简单的、按位置排列的全连接前馈网络。在两个子层的每一个周围采用了一个残差连接,然后进行层的归一化。也就是说,每个子层的输出是LayerNorm(x + Sublayer(x)),其中,Sublayer(x)是子层本身实现的函数。为了方便这些残差连接,模型中的所有子层以及嵌入层都会生成尺寸为dmodel=512的输出。


解码器。解码器也是由N=6个相同的层组成的堆栈。除了每个编码器层的两个子层之外,解码器还插入了第三个子层,它对编码器堆栈的输出进行多头注意力。与编码器类似,在每个子层周围采用残差连接,然后进行层归一化。进一步修改了解码器堆栈中的自注意力子层,以防止位置关注后续位置。这种masking,再加上输出嵌入偏移一个位置的事实,确保对位置i的预测只取决于小于i的位置的已知输出。


Attention。注意力函数可以描述为将一个查询和一组键值对映射到一个输出,其中,查询、键、值和输出都是向量。输出被计算为值的加权和,其中分配给每个值的权重是由查询与相应的键的兼容性函数计算的。在Transformer中使用的Attention是Scaled Dot-Product Attention, 是归一化的点乘Attention,假设输入的query q 、key维度为dk,value维度为dv , 那么就计算query和每个key的点乘操作,并除以dk ,然后应用Softmax函数计算权重。Scaled Dot-Product Attention的示意图如图7(左)。


图7. (左)按比例的点乘法注意力。(右)多头注意力由几个平行运行的注意力层组成


如果只对Q、K、V做一次这样的权重操作是不够的,这里提出了Multi-Head Attention,如图7(右)。具体操作包括:

  1. 首先对Q、K、V做一次线性映射,将输入维度均为dmodel 的Q、K、V 矩阵映射到Q∈Rm×dk,K∈Rm×dk,V∈Rm×dv;
  2. 然后在采用Scaled Dot-Product Attention计算出结果;
  3. 多次进行上述两步操作,然后将得到的结果进行合并;
  4. 将合并的结果进行线性变换。

在完整的架构中,有三处Multi-head Attention模块,分别是:

  1. Encoder模块的Self-Attention,在Encoder中,每层的Self-Attention的输入Q=K=V , 都是上一层的输出。Encoder中的每个位置都能够获取到前一层的所有位置的输出。
  2. Decoder模块的Mask Self-Attention,在Decoder中,每个位置只能获取到之前位置的信息,因此需要做mask,其设置为−∞。
  3. Encoder-Decoder之间的Attention,其中Q 来自于之前的Decoder层输出,K、V 来自于encoder的输出,这样decoder的每个位置都能够获取到输入序列的所有位置信息。

在进行了Attention操作之后,encoder和decoder中的每一层都包含了一个全连接前向网络,对每个位置的向量分别进行相同的操作,包括两个线性变换和一个ReLU激活输出:


因为模型不包括recurrence/convolution,因此是无法捕捉到序列顺序信息的,例如将K、V按行进行打乱,那么Attention之后的结果是一样的。但是序列信息非常重要,代表着全局的结构,因此必须将序列的token相对或者绝对位置信息利用起来。这里每个token的position embedding 向量维度也是dmodel=512, 然后将原本的input embedding和position embedding加起来组成最终的embedding作为encoder/decoder的输入。其中,position embedding计算公式如下:


其中,pos表征位置,i表征维度。也就是说,位置编码的每个维度对应于一个正弦波。波长形成一个从2π到10000-2π的几何级数。选择这个函数是因为假设它可以让模型很容易地学会通过相对位置来参加,因为对于任何固定的偏移量k,PE_pos+k可以表示为PE_pos的线性函数。


当前 SOTA!平台收录 Transformer 共 9 个模型实现资源,支持的主流框架包含 TensorFlow、PyTorch等。

项目 SOTA!平台项目详情页
Transformer 前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/transformer-2

前往 SOTA!模型资源站(sota.jiqizhixin.com)即可获取本文中包含的模型实现代码、预训练模型及API等资源。 

网页端访问:在浏览器地址栏输入新版站点地址 sota.jiqizhixin.com ,即可前往「SOTA!模型」平台,查看关注的模型是否有新资源收录。

移动端访问:在微信移动端中搜索服务号名称「机器之心SOTA模型」或 ID 「sotaai」,关注 SOTA!模型服务号,即可通过服务号底部菜单栏使用平台功能,更有最新AI技术、开发资源及社区动态定期推送。

相关文章
|
6月前
|
存储 自然语言处理 前端开发
2025年大模型发展脉络:深入分析与技术细节
本文深入剖析2025年大模型发展脉络,涵盖裸模型与手工指令工程、向量检索、文本处理与知识图谱构建、自动化提示生成、ReAct多步推理及AI Agent崛起六大模块。从技术细节到未来趋势,结合最新进展探讨核心算法、工具栈与挑战,强调模块化、自动化、多模态等关键方向,同时指出计算资源、数据质量和安全伦理等问题。适合关注大模型前沿动态的技术从业者与研究者。
1952 9
|
12月前
|
负载均衡 算法 应用服务中间件
nginx反向代理与负载均衡
nginx反向代理与负载均衡
324 2
|
12月前
|
搜索推荐 安全 数据挖掘
如何利用商品详情数据挖掘消费者的潜在需求?
本文介绍了利用商品详情数据挖掘消费者潜在需求的六种方法,包括分析商品属性信息、研究消费者评价反馈、关注搜索浏览行为、对比竞争对手数据、分析购买行为及利用数据挖掘技术进行综合分析,旨在帮助企业精准捕捉市场需求,优化产品和服务。
|
11月前
|
人工智能 前端开发 IDE
通义灵码一周年测评:@workspace 和 @terminal 新功能体验分享
作为一名前端开发工程师,我近期体验了通义灵码的@workspace和@terminal新功能。@workspace通过智能解析项目结构,帮助快速上手新项目;@terminal则提供内置命令行环境,简化代码调试和系统管理。这两项功能显著提升了开发效率和代码管理的便捷性,是前端开发的得力助手。
通义灵码一周年测评:@workspace 和 @terminal 新功能体验分享
|
监控 Ubuntu Unix
Linux |Nethogs 监控网络使用情况
Linux |Nethogs 监控网络使用情况
Linux |Nethogs 监控网络使用情况
|
存储 算法 安全
MD5哈希算法:原理、应用与安全性深入解析
MD5哈希算法:原理、应用与安全性深入解析
|
运维 Ubuntu JavaScript
【Linux】Linux命令快速学习神器tldr、cheat介绍和使用(一)
【Linux】Linux命令快速学习神器tldr、cheat介绍和使用
620 0
【Linux】Linux命令快速学习神器tldr、cheat介绍和使用(一)
|
机器学习/深度学习 算法 测试技术
蚂蚁集团开源代码大模型CodeFuse!(含魔搭体验和最佳实践)
蚂蚁集团在刚刚结束的2023外滩大会上开源了代码大模型CodeFuse,目前在魔搭社区可下载、体验。
|
算法 安全 物联网安全
物联网安全|位置隐私保护方法
物联网安全|位置隐私保护方法
910 15
物联网安全|位置隐私保护方法
|
消息中间件 关系型数据库 MySQL
使用Flink的MySQL连接器
使用Flink的MySQL连接器
639 4

热门文章

最新文章