transformer原理总结

简介: transformer原理总结

Transformer是一个 Sequence to Sequence model,应用了 self-attention它兼具RNN和卷积网络的优点:并行化和关注全局信息

结构图如下:

01b6d1934d6d47aea00e2ff8c8967472.png左侧框是Encoder结构,右侧框是Decoder,上侧是分类网络结构。

1 Encoder

单个Encoder单元流程如下:

  1. 1. 我们可以看到输入.shape=(B, L, D),其中B是batch_size, L是sequence length,D是dim_in。
  2. 2. 之后进入Multi-Head Attention(多头注意力模型)中,输出的维度仍然是(B, L, D)
  3. 3. 再将结果与原输入进行相加(借鉴残差结构),用于防止网络退化,输出维度是(B, L, D)
  4. 4. 再通过Layer normalization,用于对每一层的激活值进行归一化,输出维度是(B, L, D)
  5. 5. 接着进入前馈神经网络Feed Forword Network,并进行Add和Layer Normalization

其中有几个步骤进行详细的总结一下,包括:

  1. 1. Layer Normalization
  2. 2. Feed Forword Network

1.1 Layer Normalization

我们比较熟悉的是批量归一化(Batch Normalization),那么其与Layer Normalization的区别是什么呢?为什么我们这里用Layer…而不用Batch…呢?

Batch Normalization作用与一个batch中的某一个channel,而Layer Normalization作用于batch中的一条样本。


d6c419aac3034cebbe0e889b658c4200.png

比如我们有一个维度为(batch_size, c, wxh)的输入:

[
  [  # batch_size=0
        [x11, x12, x13, x14],
      [x21, x22, x23, x24],
      [x31, x32, x33, x34],
    ],
    [  # batch_size=1
        [y11, y12, y13, y14],
      [y21, y22, y23, y24],
      [y31, y32, y33, y34],  
    ]
].shape = (2, 3, 4)

如果我们进行Batch Normalization,我们将把x11-x31,y11-y31放在一起进行归一化,而Layer Normalization会把X11-X34即batch_size=0的整个矩阵进行归一化。其中

1.2 Feed Forword Network

Feed Forword Network是一个两层的全连接层,第一层以ReLU作为激活函数,第二层不适用激活函数,其计算公式如下:

ba978ce36255463498d2f5ab726cac5b.png

其中X是Feed Forword Network的输入,最终的输出与输入X一致。

1.3 xN

上面介绍了一个Encoder单元的工作,其实整个Encoder是多个这种单元嵌套组成的,计算公式如下图

1cd10badd9f04c00bbe7dcb775461f47.png

2 Decoder

输出:位置i的输出词的概率分布

输入:Encoder的输出和第i-1位置Decoder的输出。

编码可以并行计算,一次性全部Encoding出来。但解码不是一次把所有序列解出来的,而是像 一样一个一个解出来的,因为要用上一个位置的输入当作attention的query 。

还是从下往上了解Decoder的结构:

  1. 1. 第i-1个Decoder的输出输入Mask Multi-Head Attention
  2. 2. 将第一步中的输出输入到第二个Multi-Head Attention生成矩阵Q,将Encoder的输出输入到第二个Multi-Head Attention生成矩阵K个V,这样每一个单词都可以利用到Encoder的所有单词信息
  3. 3. 将结果输入到Feed Forward Network网络中,这与Encoder一致

2.1 Masked Multi-Head Attention

相比Multi-Head Attention,这里主要是在Scale之后,Softmax之前进行了一次Mask操作。

8f49cacd533e4a00a743e2e3a480793e.png

主要的作用是使attention只会attend on已经产生的sequence

那么为什么使用mask操作呢?

Decoder 可以在训练的过程中使用Teacher Forcing并且并行化训练,即将正确的单词序列 ( I have a cat) 和对应输出 (I have a cat ) 传递到 Decoder。那么在预测第 i 个输出时,就要将第 i+1 之后的单词掩盖住。

具体如何做的呢,举例子进行说明?

假设我们有一个输入矩阵X=“ I have a cat”,映射为(0, 1, 2, 3, 4),通过Embedding后维度为(5, N)

按照self-attention的操作生成矩阵A(未进行softmax),维度为(5, 5),如下图所示

b300a5d1a545450693ea4f2e3c2c4bd0.png

我们假设有一个Mask矩阵,其大小为(5, 5),如下图所示。

c4280e6f947f4248953af864fb202e29.png

我们将生成的A矩阵按位与Mask相乘,得到遮掩后的矩阵A’。其中单词0只能使用单词0,单词1只能使用单词0和单词1…

36b875bf988f42e282b2effdfc0b181f.png

将生成的 Mask QK^T与矩阵V相乘,得到只与前i-1个单词相关的矩阵信息

045c56fdb23541ca8063b75cd0a908e2.png

  1. 接下来的操作与Multi-Head Attention一致

3 Softmax输出预测

通过Feed Forward Network的结果通过一个全连接层得到结果如下:

ff811868d55e4f70ba4cb3923ce35d9f.png

我们进行softmax预测下一个单词

26f91686bc0b433fa3dd530f6a352a5c.png

至此,transformer的网络总结结束,还需要对代码进行总结。

4 参考文章

https://blog.csdn.net/qq_37541097/article/details/117691873?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522168241225816800182716096%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257D&request_id=168241225816800182716096&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2blogfirst_rank_ecpm_v1~rank_v31_ecpm-3-117691873-null-null.blog_rank_default&utm_term=transformer&spm=1018.2226.3001.4450

https://zhuanlan.zhihu.com/p/340149804


相关文章
|
1月前
|
机器学习/深度学习 人工智能 数据可视化
图解Transformer——注意力计算原理
图解Transformer——注意力计算原理
50 0
|
3月前
|
PyTorch 算法框架/工具 C++
Bert Pytorch 源码分析:二、注意力层
Bert Pytorch 源码分析:二、注意力层
51 0
|
机器学习/深度学习 自然语言处理 算法
Transformer 模型:入门详解(1)
动动发财的小手,点个赞吧!
12947 1
Transformer 模型:入门详解(1)
|
1月前
|
机器学习/深度学习 自然语言处理 语音技术
Transformer框架
Transformer框架
21 1
|
5月前
|
机器学习/深度学习 自然语言处理
深度剖析Transformer核心思想 "Attention Is All You Need"
深度剖析Transformer核心思想 "Attention Is All You Need"
128 1
|
4月前
|
机器学习/深度学习 人工智能 关系型数据库
简化版Transformer :Simplifying Transformer Block论文详解
在这篇文章中我将深入探讨来自苏黎世联邦理工学院计算机科学系的Bobby He和Thomas Hofmann在他们的论文“Simplifying Transformer Blocks”中介绍的Transformer技术的进化步骤。这是自Transformer 开始以来,我看到的最好的改进。
60 0
|
8月前
|
机器学习/深度学习 自然语言处理 索引
【Transformer系列(4)】Transformer模型结构超详细解读
【Transformer系列(4)】Transformer模型结构超详细解读
189 0
【Transformer系列(4)】Transformer模型结构超详细解读
|
8月前
|
机器学习/深度学习 算法 PyTorch
【vision transformer】DETR原理及代码详解(一)
【vision transformer】DETR原理及代码详解
540 0
|
8月前
【vision transformer】DETR原理及代码详解(二)
【vision transformer】DETR原理及代码详解
54 0
|
8月前
|
SQL API
【vision transformer】DETR原理及代码详解(四)
【vision transformer】DETR原理及代码详解
257 0

相关实验场景

更多