【RetNet】论文解读:Retentive Network: A Successor to Transformer for Large Language Models

简介: 【RetNet】论文解读:Retentive Network: A Successor to Transformer for Large Language Models

Retentive Network: A Successor to Transformer for Large Language Models

这是由微软和清华一起发布的论文,在这篇文章中,作者提出了一个代替transformer的基础架构:retentive network,该架构可以同时达到 training parallelism, low-cost inference, good performance. 并从理论上推导出了recurrence与attention之间的联系,这种retentive机制在序列模型中有三种计算范式:parallel, recurrent, and chunkwise recurrent. 其中并行表示允许训练并行性。循环表示可以实现低成本的O (1)推理,从而在不牺牲性能的情况下提高了解码需要的吞吐量、并降低了延迟和GPU内存。块级递归表示促进了具有线性复杂度的高效长序列建模,其中每个块都被并行编码,同时递归地总结块。

一、完整代码

https://github.com/microsoft/unilm/tree/master/retnet

二、论文解读

2.1 介绍

论文开头提出了一个不可能三角,分别是training parallelism, low-cost inference, good performance;以往的架构只能获得三种优势中的两种,而RetNet可以全部获得;

首先是Linear Transformer :其主要处理的方式是对k和v进行处理,例如[Linformer]论文实现:Linformer: Self-Attention with Linear Complexity_linformer网络结构-CSDN博客是通过证明self-attention是一个低秩矩阵来减少k和v的维度进而得到线性复杂度的效果,即low-cost inference,但是其降低了Transformer的效果;

第二个是Recurrent Network,随着不断的优化,其最大的缺点就是不能并行训练;

最后一个是Transformer,其最大的不足便是复杂度是 O ( n 2 ) O(n^2) O(n2),这导致序列长度的增加增加了GPU内存消耗和延迟,并降低了推理速度。

这里论文提出的RetNet,可以同时获得training parallelism, low-cost inference, good performance三种优秀的性质,其通过采用一种multi-scale retention 机制去替换multi-head attention

作者通过实验表明,RetNet在scaling curves序列长度和in-context learning上下文学习方面相较于Transformer是持续超过的状态,同时,RetNet的inference cost 是 O ( 1 ) O(1) O(1) ;对于7B参数量和8k序列长度的语言模型,RetNet的解码速度比具有键值缓存的Transformer快8.4×,节省了70%的内存。在训练过程中,RetNet还比Transformer节省了25-50%的内存和7×的加速,并且在highly-optimized FlashAttention方面具有优势。此外,RetNet的推理延迟对批处理大小不敏感,允许巨大的吞吐量,认为是遥遥领先的;

2.2 Retentive Networks

在介绍框架之前,应该对Transformer和RoPE旋转位置编码有一定的了解,可以看下面两篇博客:

[transformer]论文实现:Attention Is All You Need_transformer vaswani论文-CSDN博客

[RoFormer]论文实现:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING_roformer: enhanced transformer with rotaryposition-CSDN博客

Retentive Network 的大致结构和 Transformer 是一致的,其不同点主要在于利用了 multi-scale retention MSR 替换掉了 multi-heads attention MHA;Retentive Network 中的Retention有三种计算方法:The Parallel Representation of Retention, The Recurrent Representation of Retention, The Chunkwise Recurrent Representation of Retention;论文中主要介绍了前面两种,并行计算的结果和循环计算的结果是一致的;

在这里我们先给出计算过程,再去证明为什么并行计算的结果和循环计算的结果一致;

Retention
The Parallel Representation of Retention

论文中相关内容如下:

可以看到在计算 Q 和 K的时候,多出现了一个 Θ,这里的 image.png 通过这篇博客[RoFormer]论文实现:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING_roformer: enhanced transformer with rotaryposition-CSDN博客,可以发现其其本质就是RoPE中的一个旋转矩阵;

这里插入介绍一下旋转矩阵的快速计算技巧:

结合下面代码:

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

theta_shift函数的返回值就是旋转矩阵,得到的旋转矩阵如图:

而 Θ ‾ \overline{\Theta} Θ 表示 Θ \Theta Θ共轭,其旋转矩阵相对于 Θ \Theta Θ,cos的结果不变,sin的结果变为相反数,从上图中可以明显的发现两个旋转矩阵是相互转置的;

从原文中看,上面的红框是很自然的推导出下面的红框的,其中十字架是共轭转置的意思,从这篇博客中得到的解释有出入关于RoPE旋转位置编码的理解-CSDN博客

RoPE论文中原文可以用下面等式化等号:

image.png

可以看到 image.png 用矩阵进行了替换,所以说如果按照这样解释的话 image.png 在并入转置矩阵中的时候,其应该是自动发生了共轭关系,应该是 image.png

从源代码中也可以发现:

def parallel_forward(self, qr, kr, v, mask):
  bsz, tgt_len, embed_dim = v.size()
  vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
  qk_mat = qk_mat * mask
  # invariant after normalization
  qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4)
  output = torch.matmul(qk_mat, vr)
  output = output.transpose(1, 2)
  return output
    
def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
    ...
  qr = theta_shift(q, sin, cos) ## here
    kr = theta_shift(k, sin, cos) ## here
    output = self.parallel_forward(qr, kr, v, inner_mask)

其对 Q和 K 采取的处理方式是一样的,所以

从我对 image.png 的解释上来说,这里的 image.png 应该还是 Θ;这样全文就通顺了;

这里还有一个Dnm 其中在 n ≥ m中 γ应该是大于0小于1的,这样就可以根据距离削弱关系,而在 n < m nn<m中出现0表示一个自回归关系,也就是说只能注意到前面的内容,无法看到后面的内容;

类似于达到如下图一样的结果:

这就是Retention的并行表示内容;

The Recurrent Representation of Retention

有计算公式:

Q=(XWQ)⊙ΘK=(XWK)⊙Θ¯V=XWVSn=γSn−1+KnTVnRetention(Xn)=QnSn,n=1,…,|x|

\begin{align} &Q=(XW_Q)\odot \Theta \\ &K=(XW_K)\odot \overline{\Theta} \\ &V=XW_V \\ &S_n= \gamma S_{n-1} + K_n^TV_n \\ &Retention(X_n)=Q_nS_n, n=1,\dots,|x| \\ \end{align}

证明结果一致

从循环推导并行:

image.png

这里要注意的是  A的维度是 d × d , Kn和  Qn的维度是 1 × d

接着这里进行了一步处理就是对角化,这里原文中的字母看着别扭,不用他们的字母: image.png

如果没有重根,任何矩阵都是可以对角化的,这里做了一个小假设,就是 A 是可以对角化的;其中 P 和  Λ的维度都是 d × d 这样的话,原式子可以化简为

  image.png

同时,由于

  image.png

由于 WQ,WK都是学习参数,我们可以把  P并入其中,再化简就可以得到: image.png

image.png 并入 image.png 中,由于 Λ 是对角矩阵得到:

image.png

这里再对 Λ进行处理,论文中令

image.png

其中 γ和  θ是一个 d维向量,但是我怎么也想不出来这个是怎么变成 d × d 维矩阵的,如果把 γ看做一个标量, image.png 看作一个对角矩阵,是很容易的就可以推出最下面这个公式: image.png

但是看到 image.png 我左看右看上看下看都觉得不舒服,感觉就是哪里出现了问题;从理论上是这样的,但是从代码实现上, KQ采取的是同样的处理方式;

这里我们把 image.png

其中

由于 R 的性质

  image.png

我们可以得到:

image.png

直到我看了复数矩阵相关的内容,对复数矩阵进行转置是自动要进行共轭操作的,从这种角度进行解释就是说原文 image.png 是错误的,应该改为 image.png ,但这只是我的猜测;如果是真错了,清华大学我真想@@他几下;

推导结束!

The Chunkwise Recurrent Representation of Retention

这里只贴一下,看起来很复杂但是用图一下就能说明;

Gated Multi-Scale Retention

Gated Multi-Scale Retention MSR 多尺度保留机制类似于多头注意力机制,模型的维度为 image.png ,每个头的维度为  d,一共有 image.png 个头,每个头和多头注意力机制的每个头一样使用不同的 image.png

由于使用了不同的 γi,导致  Retention的尺度会发生变化,需要进行Normalization;

这里利用GroupNorm的尺度不变性来Normalization提高保留层的数值精度,由于 image.png

image.png

由于尺度不变性,这并不影响最终结果

Overall Architecture of Retention Networks

类似于Transformer的红框部分:

最主要区别是把MHA换成了MSR;

2.3 对比

三、过程实现

略;

四、整体总结

网络上说抄袭的rwkv模型…


目录
相关文章
|
机器学习/深度学习 人工智能 自然语言处理
Paper:GPT-3《 Language Models are Few-Shot Learners》的翻译与解读(四)
Paper:GPT-3《 Language Models are Few-Shot Learners》的翻译与解读
|
5天前
|
机器学习/深度学习 数据采集 自然语言处理
[GPT-2]论文解读:Language Models are Unsupervised Multitask Learners
[GPT-2]论文解读:Language Models are Unsupervised Multitask Learners
14 1
|
5天前
|
机器学习/深度学习 自然语言处理 并行计算
[Bert]论文实现:BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
[Bert]论文实现:BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
16 1
|
5天前
|
机器学习/深度学习 自然语言处理 TensorFlow
[Character Embedding]论文实现:Text Understanding from Scratch
[Character Embedding]论文实现:Text Understanding from Scratch
11 2
|
5天前
|
机器学习/深度学习 JSON 自然语言处理
[GPT-1]论文实现:Improving Language Understanding by Generative Pre-Training
[GPT-1]论文实现:Improving Language Understanding by Generative Pre-Training
45 1
|
5天前
|
Python
[UNILM]论文实现:Unified Language Model Pre-training for Natural Language.........
[UNILM]论文实现:Unified Language Model Pre-training for Natural Language.........
12 0
|
9月前
【COT】Chain-of-Thought Prompting Elicits Reasoning in Large Language Models
【COT】Chain-of-Thought Prompting Elicits Reasoning in Large Language Models
131 0
|
9月前
|
机器学习/深度学习 资源调度 算法
【RLchina第四讲】Model-Based Reinforcement Learning(上)
【RLchina第四讲】Model-Based Reinforcement Learning(上)
241 0
|
9月前
|
机器学习/深度学习 算法
【RLchina第四讲】Model-Based Reinforcement Learning(下)
【RLchina第四讲】Model-Based Reinforcement Learning(下)
118 0
|
10月前
|
机器学习/深度学习 算法 计算机视觉
【计算机视觉 | 目标检测】Open-vocabulary Object Detection via Vision and Language Knowledge Distillation
在这项工作中,我们考虑借用预训练的开放词汇分类模型中的知识来实现open vocabulary检测。