小白总结Transformer模型要点(一)(上)

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
简介: 本文主要总结了Transformer模型的要点,包含模型架构各部分组成和原理、常见问题汇总、模型具体实现和相关拓展学习。

前言

本文主要总结了Transformer模型的要点,包含模型架构各部分组成和原理、常见问题汇总、模型具体实现和相关拓展学习。

一、模型架构

0.背景知识

seq2seq模型:

由Encoder和Decoder共同组成,之间由Attention机制来建立关联性。

可以分为3类:

CNN

  • 权重共享(平移不变性、可并行计算)

  • 滑动窗口(局部关联性建模)

  • 对相对位置敏感、对绝对位置不敏感

RNN(依次有序递归建模)

  • 对顺序敏感(当前的输入依赖于上一层的输出)

  • 串行计算耗时

  • 长程建模能力弱

  • 单步计算复杂度不变,计算复杂度与序列长度呈线性关系

  • 对相对位置和绝对位置都敏感

TRM

  • 无局部假设(可并行计算,对相对位置不敏感)

  • 无有序假设

需要增加位置编码来反映位置变化对特征的影响;

对绝对位置不敏感。

  • 任意两字符均可建模

擅长长短程建模;

自注意力机制需要序列长度的平方级别复杂度。

1.整体架构

2345_image_file_copy_151.jpg

6个Encoder的结构相同,但不是完全相同,只是结构相同、参数不同,在训练时不是训练一个Encoder、再复制到6份,而是6个Encoder都独立训练,这与预训练模型ALBERT共享Transformer中的某些层的参数达到减少BERT参数量的目的是有所区别的;

6个Decoder的结构也相同,参数不同,与Encoder类似。

2345_image_file_copy_153.jpg

输入和输出的说明如下:

2345_image_file_copy_154.jpg

可以看到,有3个输入,解码端的真实标签,与解码端的输出计算损失,同时解码端是不能并行的,只能顺序执行,因为这一层解码器(当前时刻)的输入取决于(依赖于)上一层(上一时刻)的输出,因此真实标签与解码端的输入是错了一位的;

同时,在实际应用中,为了加快训练时的收敛速度,此时会使用到Teacher Forcing,即将真实标签与解码端原来的输入一起输入,为了不影响训练(即在训练当前单词时不看到后面的单词),此时就需要将当前单词后面的单词全部mask住,以达到更好的预测效果。

实际计算中是将多个句子作为一个Batch来处理的,可以使用矩阵来加快计算,但是句子的长度可能不一致,此时超过最大长度的就被舍弃,不够最大长度的就用Padding填充字符来填充,在注意力层中将其置为-∞,避免其对其他词产生影响:

2345_image_file_copy_156.jpg

2.Embedding和位置编码

Encoder包含3部分:

  • 输入部分
  • 注意力机制
  • 前馈神经网络

2345_image_file_copy_157.jpg

输入部分包含Embedding和位置嵌入(位置编码)。

Embedding:

2345_image_file_copy_158.jpg

输入部分包含Embedding和位置嵌入(位置编码)。

Embedding:

2345_image_file_copy_159.jpg

使用随机初始化word2vec都可以,具体可以根据实际使用到的情况选择;

Embedding由稀琉的one-hot进入一个不带bias的FFN得到一个稠密的连续向量,用来表征单词。

2345_image_file_copy_160.jpg

从RNN到位置编码:

RNN的结构天然与时序关系很符合,可以实现先处理某些数据、再处理另外的数据的效果。

(1)RNN的参数共享

RNN的U输入参数、W隐层参数和输出V是一套参数,对于所有的time step都共享一套参数,例如对于NLP任务来说,所有的单词都共用了这一套参数。

(2)RNN的梯度消失

不是因为连乘效应造成了梯度衰减,这里的梯度是梯度的和,其梯度消失不是指梯度逐渐趋近于0,而是总梯度被近距离梯度主导、被远距离梯度忽略不计。

在TRM中:

实现了并行化,可以一起处理多个单词、而不是逐个地处理,这样可以加快处理速度,但是会忽略掉单词之间的相对位置信息,这时候就需要位置编码。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wKGspfzc-1654266200425)(image/image_h8VM501H7C.png)]

之所以选择sin/cos来表征位置编码,是因为:

每个位置都是固定的

对于不同的句子,相同位置的距离一致

泛化能力较强,可以推广到更长的句子

通过使用sin/cos,可以使得pe(pos+k)可以写成pe(pos)和pe(k)的线性组合。

将位置编码和嵌入相加(维度相同):

2345_image_file_copy_162.jpg

绝对位置向量包含着相对位置信息:

2345_image_file_copy_163.jpg

同时,虽然在注意力机制中会消除相对位置信息,但是由于残差连接的存在,因此位置编码表征的位置信息可以向更高的层传递(流入深层),因此位置信息是一直存在的,不会消失。

这里之所以选择正弦和余弦函数,是因为正余弦函数可以拆解,从而将以后时刻的位置表示为前面位置的线性组合,因此可以增加泛化能力。

具体实现位置编码时,有2种方式:

(1)将得到的PE矩阵直接与Word Embedding相加

(2)将得到的PE来构造Embedding,并对词语序列进行Embedding编码,得到位置的Embedding,再与Word Embedding相加。

相关文章
|
负载均衡 算法 网络协议
slb监听协议UDP
SLB的UDP监听器适用于实时性高、数据完整性要求低的场景,如视频流和在线游戏。它无连接、不可靠,不保证数据顺序和重传,适合延迟敏感应用。SLB进行UDP会话保持依赖应用层协议或数据包标识符,使用定制健康检查检测后端服务器状态,并支持多种负载均衡算法。配置时注意网络环境对UDP的支持,确保流量畅通。
240 4
|
9月前
|
机器学习/深度学习 数据采集 自然语言处理
Transformer 学习小结(输出输入)
在模型处理中,输入文本需经预处理,包括分词、词汇表构建及填充(padding),并使用填充掩码避免无效计算。位置嵌入为Transformer提供顺序信息,编码器通过自注意力机制和前馈网络处理输入序列。输出处理中,解码器根据编码器输出生成目标序列,使用序列掩码防止信息泄露,逐步生成单词,并在测试阶段采用贪婪或束搜索优化输出。
|
自然语言处理 API Python
LLaMA
【9月更文挑战第26天】
507 63
|
关系型数据库 MySQL 应用服务中间件
windows服务器自带IIS搭建网站并发布公网访问【内网穿透】-1
windows服务器自带IIS搭建网站并发布公网访问【内网穿透】
|
存储 分布式计算 Serverless
阿里云 EMR Serverless Spark 版开启免费公测
EMR Serverless Spark 版免费公测已开启,预计于2024年06月25日结束。公测阶段面向所有用户开放,您可以免费试用。
1654 5
|
存储 JSON 安全
[浏览器系列] : 客户端本地存储
[浏览器系列] : 客户端本地存储
209 2
[浏览器系列] : 客户端本地存储
|
关系型数据库 MySQL Apache
Ubuntu22.04搭建LAMP环境
LAMP是一个用于构建Web应用程序的技术堆栈,你可以用它开发很多Web程序,比如WordPress。如果你想手工在VPS上搭建WordPress的话,那么你就需要先搭建LAMP环境。这篇文章讲解如何在Ubuntu22.04上搭建LAMP环境。首先,你需要先注册一台VPS服务器,然后登录VPS安装Apache服务、安装MySQL数据库,以及安装PHP。
311 0
Ubuntu22.04搭建LAMP环境
|
机器学习/深度学习 人工智能 算法
【机器学习】平均绝对误差 (MAE) 与均方误差 (MSE) 有什么区别?
【5月更文挑战第17天】【机器学习】平均绝对误差 (MAE) 与均方误差 (MSE) 有什么区别?
|
机器学习/深度学习 自然语言处理 索引
【Transformer系列(4)】Transformer模型结构超详细解读
【Transformer系列(4)】Transformer模型结构超详细解读
1866 0
【Transformer系列(4)】Transformer模型结构超详细解读
|
存储 前端开发 JavaScript
驾校预约|基于Vue+Springboot驾校预约系统的设计与实现
驾校预约|基于Vue+Springboot驾校预约系统的设计与实现
407 0

热门文章

最新文章