条件生成-1|深度学习(李宏毅)(十四)

简介: 条件生成-1|深度学习(李宏毅)(十四)

一、Generation


generation的目的是生成有结构的某些东西,比如句子、图片等。比如我们可以使用RNN来生成句子,如下图,我们在训练RNN时可以输入当前字然后使得RNN输出下一个字,如此就可以使得模型输出一个句子:


3)BB[FC7FE`K)X_(4~21L$3.png

                                                    生成句子


我们也可以尝试将生成句子的RNN模型用于生成图片上,在这里我们将图片的每个像素看做一个word,如下图:


@YO}P0QSFYZM108[5HU{PLG.png

                                                     生成图片


然后使用类似上面的模型进行训练就可以生成一张图片,但是使用这种方法有一个问题。我们可以看到图片的像素是按照以下顺序产生的:


0Z00TKB239Q$PX3~@3{R@LO.png

        生成图片


显然模型忽视了像素之间的几何关系,也就是说比如左边的像素收到上面的像素影响比收到上面右方像素的影响要大一些,而使用这种网络结构很难学到这种关系,我们希望模型能够学习到以下像素之间的相互影响关系:


61GX]VUZUG{AW3%5F}2MC]5.png

         生成图片


为了解决这个问题,我们可以采用3D的LSTM来进行图片的生成,如下图,每个3D LSTM的cell接受来自三个方向的输入然后向三个方向进行输出:


ABA}6@TYH2_HBQ@RIYES]W8.png

                                3D LSTM


使用一个卷积核在图片上移动就可以按上述像素影响关系产生图片:


32WQOX(N)(]G~MPU`RAF@6D.png

                                         生成图片


二、Conditional Generation


条件生成要求神经网络不只是随机地生成一些图片或者句子等,还要根据需要产生相应的输出,比较典型的场景如下:


)PUN(R~MO]S}6RST_MQ[FFC.png

                                                  条件生成


  1. Image Caption Generation


比如一个任务是为模型输入一张图片,模型要产生图片的说明,一个可行的做法是将图片丢进一个CNN里提取特征,然后将提取到的特征向量输入到RNN里,使得RNN产生这张图片的描述,为了使得模型每个时刻的输出都会考虑到图片的影响,可以将特征向量输入给每一个时刻,该过程的流程图如下:


2`ENPS~Z5}EQNT`9VG@30PP.png

                               Image Caption Generation


  1. Machine translation


在做机器翻译时可以将中文输入到一个RNN中,然后取最后一个时间点的输出,这个输出可以认为包含了整个中文句子的信息,然后将输出的这个向量输入到另一个RNN中的每个时间点,使这个RNN来输出对应的英文翻译结果。


需要注意这里的两个RNN是一起训练的(jointly train),同时两个RNN的参数既可以使用同样的参数,也可以使用不一样的参数。通常在数据量比较小时使用同样的参数即可,这样参数比较少,比较容易避免过拟合。


[218P~L2RX(@`Z{0`W820AN.png

                                          Machine translation


  1. Chat-bot


在做聊天机器人时经常会遇到一个问题,就是模型只考虑上一个时刻的状态,而忽略了之前时刻的状态,这也就导致了下图中模型会和用户打两次招呼的问题。因此我们需要模型能够考虑到longer context的影响。解决这个问题的方式是使用一个双层的Encoder,也就是再加入一个RNN来记录对话,第一层RNN会输入人和机器聊天的句子然后获得每个句子对应的输出,然后第二层RNN会将这些输出作为输入然后将自己的输出作为另一个RNN的输入来获得Chat-bot的回答:


4QUZE_%E$F62TQ6$WH_2R$L.png

                                                Chat-bot


三、Attention


  1. 基本过程


同样以机器翻译为例,在上面的机器翻译方法中,我们每次都把由Encoder RNN提取的包含整个序列信息的向量输入到Decoder RNN中的每个时间点,但是可以想象如果在翻译“machine”时,只关注“机器”这两个字的信息而不是整个序列的信息,翻译的效果会更好一些,同样地翻译“learning”时只关注“学习”。Attention-based model可以帮助我们做到这件事。


$$H~6MV530}QHK{W1P1A9{P.png

                                              Attention


对于使用的match函数可以自定义来设计,举例来说有以下但不限于这些方式:


O`%8CPP)6LGQR[M290@NEGS.png

如果match函数中有参数,则这些参数应该和整个网络一起学习得到:


1DZ9F_BESFY[0R_8_U0[CV2.png

                     match function

Z$~~[YQD(]HV54N%{1%)SK7.png

                                                        Attention

UU{]@M$1$W~@Y0JV9JIVGQK.png

RJA}C2D0Q~[F_L$G5G)1GR4.png

                                    Attention


需要注意的是这里可以用RNN的隐层状态作为VQ`8$_]ECS{~]@9YYA(R(UT.png,但是这里的方法并不固定,比如还可以用RNN的隐层状态再通过一个隐藏层的输出作为VQ`8$_]ECS{~]@9YYA(R(UT.png,具体方法可以自行选择,重要的是理解Attention机制。


上述过程将会一直进行下去直到翻译结束得到“.”才结束。


  1. 应用


  • Speech Recognition


可以将Attenton机制应用在语音辨识任务上,这个任务是指将输入的声音讯号转换成文字。


下图表示了使用Attenton来做语音辨识的效果,下图中上面的声音讯号可以看做一个个列向量,下面的列向量表示了每个时刻attention的结果,黑白色块的颜色代表了attention的权重大小:


VD1WAJWLO8$XOUV_PBFUDJ8.png

                                                      语音辨识


下表对比了Attenton的方法与传统的语音辨识方法:

@)6GISJ%XU`56LUKS]{7OU0.png

                                   对比


WER指字错误率,这个值越低表示识别效果越好,可以看到LAS的方法比起传统方法还是逊色了一些,LAS是指论文《Listen,Attend and Spell》中的方法,即Attention的方法。


虽然效果上不如传统方法,但是Attention的这种方法比起传统方法更加地简便,只需要直接进行训练就可以了。


  • Image Caption Generation


在做图片的内容描述时也可以用到Attention这种技术,首先需要使用CNN提取特征向量,这里使用的是在将卷积核的输出作为特征向量,也就是未做flatten(展平)之前的卷积核的输出:


ZJAO)OTQ72T9JE{E%P02XXI.png

                        特征向量提取


然后使用attention来使得神经网络实现“看图说话”的功能:


LQ`PMVWM_~6R$P754@S}JMH.png

                              Attention


下图展示了效果,每张图的白色的地方是指在生成划线的单词时所attent到的地方:


IP]UTD51SF(CMKTQA5SUTTF.png

                                                  效果


这里也有一些失败的例子,使用Attention也可以看到为什么会产生失败的结果:


TL[7QK5[Y_)X9FMEB7N7J4V.png

                                                 失败的例子


也可以用Attention来看一段视频产生一些说明,如下图,Ref指的正确的说明,柱状图和词的颜色表明了网络在产生这个说明时所attent到的视频帧:

(T111[083_85R(405Y}}0QG.png

                                                   视频描述


四、Memory Network


将Attention的机制应用在Memory Network上会有不错的效果,使用Memory Network可以进行阅读理解,主要形式是我们有一篇文章(document),一个问句(query),然后神经网络要生成一个答案(answer)。


首先我们需要将document中的句子表示成一些向量,这里可以使用paragraph2vec或者bag of words都可以,同样地query也要表示成一个向量。然后我们需要使用query向量在document的若干向量上做Attention,然后得到权值%NEMYB1U`DAY7VKOQ2(AQ1V.png,然后将document的若干向量做weighted sum,加和后要将得到的向量和query向量一起输入进一个DNN中,最终输出answer。这整个过程,包括对文章做Embedding的部分是可以一起训练的(jointly train)。整个过程流程图如下:


K$W(9~AFRERBBV`@3BXE097.png

                                              Memory Network


~1M1{V{4~BRIJ[_M_X_GEME.png

                                       Memory Network


接下来将上图展开具体介绍一下Hopping的过程。在下图中,我们可以将Hopping的过程看成多层网络叠加在一起。下图以两层为例,我们将Attention的结果再与query加起来,然后再在另外两组document的Embedding上做Attention,然后输入到DNN中然后输出answer。图中docment的四组Embedding向量可以用四组不一样的,也可以用两组一样的,这取决于设置的参数共享方式,使用两组一样的可以减少参数,使用四组不一样的可以增强效果:

DGQHJ[{2YXNBC[)(UX`[BZK.png

                                        Hopping


五、Neural Turing Machine


相比于Memory Network,Neural Turing Machine不仅可以从memory中读取信息,也可以通过Attention机制来修改memory。


_VCBX2AS8]87LNJ%ZIZ)L}J.png

                                      Neural Turing Machine

通过_QIA(5T3F`}7CWW6Z$7QAZL.png和memory计算余弦相似度来获得Attention的值是一种简化版本,真正的Neural Turing Machine获取Attention的值是通过下图的方式:


8JOGS%EW~E4S~9_JB32OJF3.png

                              Attention

F}3SJB%X50NBN@02NXQ%T6Q.png

                                          修改memory

然后用新的memory和Attention的权重做上述同样的过程产生QI~5Z]%[JPLOQ9}G}][NU2S.png。如果使用的controller是循环神经网络的话也会产生其本身的隐层状态0K)D%`N(`M%)VN~PA1OCD[G.png,然后将这个隐层状态输入到下一个时间点:


9VFFQD806EW{AJO0A@B)46V.png

                                     Neural Turing Machine


以上就是对Neural Turing Machine的简单介绍。

相关文章
|
11月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
1639 2
|
4月前
|
人工智能 缓存 搜索推荐
1688图片搜索API接口解析与 Python实战指南
1688图片搜索API接口支持通过上传图片搜索相似商品,适用于电商及商品推荐场景。用户上传图片后,经图像识别提取特征并生成关键词,调用接口返回包含商品ID、标题和价格的相似商品列表。该接口需提供图片URL或Base64编码数据,还可附加分页与筛选参数。示例代码展示Python调用方法,调试时建议使用沙箱环境测试稳定性,并优化性能与错误处理逻辑。
|
7月前
|
存储 数据采集 数据库
Python爬虫实战:股票分时数据抓取与存储
Python爬虫实战:股票分时数据抓取与存储
|
11月前
|
Web App开发 Java iOS开发
webp详解
WebP是一种由谷歌开发的图像文件格式,旨在提供更高效的图像压缩方法,以加快网页加载速度。它支持有损和无损压缩模式,并且在相同的视觉质量下,相比JPEG和PNG等格式,文件大小更小,从而优化了网络传输效率。此外,WebP还支持透明度和动画图像。
|
10月前
|
弹性计算 安全 网络安全
阿里云服务器租用流程,四种阿里云服务器租用方式图文教程参考
阿里云服务器可以通过自定义租用、一键租用、云市场租用和活动租用四种方式去租用,不同的租用方式适合不同的用户群体,例如我们只是想租用一款配置较低且可以快速部署应用的云服务器,通常可以选择一键租用或者云市场租用,本文为大家展示不同租用方式的适合对象以及租用流程,以供初次租用阿里云服务器的用户参考和选择。下面是阿里云服务器租用的图文操作步骤。
10660 2
|
网络协议 网络架构
计算机网络:思科实验【5-IPv4地址——分类地址与划分子网】
计算机网络:思科实验【5-IPv4地址——分类地址与划分子网】
|
存储 JSON 前端开发
【Java】用@JsonFormat(pattern = “yyyy-MM-dd“)注解,出生日期竟然年轻了一天
在实际项目中,使用 `@JsonFormat(pattern = "yyyy-MM-dd")` 注解导致出生日期少了一天的问题,根源在于夏令时的影响。本文详细解析了夏令时的概念、`@JsonFormat` 注解的使用方法,并提供了三种解决方案:在注解中添加 `timezone = GMT+8`、修改 JVM 参数 `-Duser.timezone=GMT+08`,以及使用 `timezone = Asia/Shanghai
1299 0
【Java】用@JsonFormat(pattern = “yyyy-MM-dd“)注解,出生日期竟然年轻了一天
idea启动java服务报错OutOfMemoryError: GC overhead limit exceeded解决方法
idea启动java服务报错OutOfMemoryError: GC overhead limit exceeded解决方法
3045 1
|
存储 SQL 算法
一文教你玩转 Apache Doris 分区分桶新功能|新版本揭秘
一文教你玩转 Apache Doris 分区分桶新功能|新版本揭秘
946 0
|
机器学习/深度学习 人工智能 自然语言处理
RT-DETR原理与简介(干翻YOLO的最新目标检测项目)
RT-DETR原理与简介(干翻YOLO的最新目标检测项目)