论文:XLNet: Generalized Autoregressive Pretraining for Language Understanding
作者:Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
时间:2020
1. 介绍
自回归语言模型(AR):具有代表性的是GPT;
AR模型是单向的,其要么是从左到右,要么是从右到左,不能考虑到所有的上下文信息,这里以从左到右为例,其目的就是最大化:
其中e(xt)表示 xt的位置编码,而 hθ(x1:t−1)表示1到t-1位置上神经网络的非线性函数,或者说上下文表示,这样得到的结果经过softmax就得到了要预测的值;
这里有一点儿缺陷,就是AR模型不能结合上下文;
自编码模型(AE):具有代表性的是BERT;
AE模型是双向的,其是通过一个整体的上下文表示来预测另一整体的上下文表示,每一个位置上的token都可以获得任何位置上的token的信息,其训练方式有两种,一种是mask
,一种是NSP
,但是NSP
被认为效果不好,这里我们以mask
来进行分析;首先,其目的是最大化:
其中 mt=1表示 xt被mask掉了, Hθ(x^)t表示隐藏t位置的隐藏状态, 表示mask的原来的token, 表示带有mask的序列;
从公式中可以看出,这里有一些缺陷:
- 假设mask独立,因为这里是直接累计的形式,其本质是假设各个mask之间是独立的;
- 训练时input里含有mask,而在下游任务中并不会出现;
2. 模型架构
尝试建立一个模型去结合AR的优点和AE的优点;
2.1 Permutation Language Modeling
长度为T
的序列,一共有T!
总不同的排列,如果每个位置的token都可能为序列中的任意一个token,由于参数共享这一原理,就相当于结合了上下文;其本质是使用类GPT模型去构造一个集合类BERT和类GPT模型;
如图:
图中利用了transformer-xl
机制,所以有memory
;
看起来挺合理的,其目标函数如下,从排列中抽样来训练模型;
但是在这里会出现一种情况,就是相同的序列可以会预测不同的token;如图:
如何解决这个问题呢,这里就要用到改变编码方式的方法;文章中叫做Architecture: Two-Stream Self-Attention for Target-Aware Representations
2.2 Two-Stream Self-Attention for Target-Aware Representations
Architecture: Two-Stream Self-Attention for Target-Aware Representations
其采取的方式就是集合content stream attention 和 query stream attention的方式:
content stream attention: hθ(xz≤t),和普通的transformer
采取的注意力机制一样;这种表示可以看到上下文和自己;
query stream attention: gθ(xz≤t),和普通的不同,这种表示只可以看到自己的位置和上下文的信息,不包含自己;
如图所示:
论文中非常详细,下图是content stream的流程:
下图是query stream的流程:
模型架构用公式表示如下所示:
在这里要提的一点是,这里不是在输入层进行排序,而是在attention中的mask进行;
同时,这里如果考虑所有的排序会非常不现实,毕竟阶乘比指数更加可怕,这里采用的办法是partial prediction
,即部分预测;
总结就是在c位置时,只选择 个排列进行预测,而不是考虑全部;
2.3 超参数设置
训练时:
可以看到这是一个较大的模型,这里有24层layers
微调时:
说明作者不心虚~
2.4 相对编码
这里采取的编码方式不同于BERT
,这里采用的是transformer-xl
的位置编码方式,相对位置编码;
2.5 模型比较
下面是一些模型的对比:
可以看到,XLNet
要显著的优于BERT
,和RoBERTa
比,其性能还是要高于RoBERTa
的;
消融实验结果如下:
从中我们可以看到XLNet
要明显优于BERT
,但要是删除memory
机制即transformer-xl
性能会明显下降,第6 - 7行表明,span-based pred
和bidirectional data
在XLNet
中都发挥着重要作用,nect-sentence prediction
效果好像并不显著,这里在原始的XLNet
中并没有加入NSP
任务进行预训练;
3. 总结
XLNet
主要就是两个东西,一个是Permutation Language Modeling,一个是transformer-xl;感觉性能相对于roberta
也没提升多少,这个模型的架构应该是不太行;