Retentive Network: A Successor to Transformer for Large Language Models
这是由微软和清华一起发布的论文,在这篇文章中,作者提出了一个代替transformer的基础架构:retentive network,该架构可以同时达到 training parallelism, low-cost inference, good performance. 并从理论上推导出了recurrence与attention之间的联系,这种retentive机制在序列模型中有三种计算范式:parallel, recurrent, and chunkwise recurrent. 其中并行表示允许训练并行性。循环表示可以实现低成本的O (1)推理,从而在不牺牲性能的情况下提高了解码需要的吞吐量、并降低了延迟和GPU内存。块级递归表示促进了具有线性复杂度的高效长序列建模,其中每个块都被并行编码,同时递归地总结块。
一、完整代码
https://github.com/microsoft/unilm/tree/master/retnet
二、论文解读
2.1 介绍
论文开头提出了一个不可能三角,分别是training parallelism, low-cost inference, good performance;以往的架构只能获得三种优势中的两种,而RetNet可以全部获得;
首先是Linear Transformer :其主要处理的方式是对k和v进行处理,例如[Linformer]论文实现:Linformer: Self-Attention with Linear Complexity_linformer网络结构-CSDN博客是通过证明self-attention是一个低秩矩阵来减少k和v的维度进而得到线性复杂度的效果,即low-cost inference,但是其降低了Transformer的效果;
第二个是Recurrent Network,随着不断的优化,其最大的缺点就是不能并行训练;
最后一个是Transformer,其最大的不足便是复杂度是 O ( n 2 ) O(n^2) O(n2),这导致序列长度的增加增加了GPU内存消耗和延迟,并降低了推理速度。
这里论文提出的RetNet,可以同时获得training parallelism, low-cost inference, good performance三种优秀的性质,其通过采用一种multi-scale retention 机制去替换multi-head attention
作者通过实验表明,RetNet在scaling curves
序列长度和in-context learning
上下文学习方面相较于Transformer是持续超过的状态,同时,RetNet的inference cost 是 O ( 1 ) O(1) O(1) ;对于7B参数量和8k序列长度的语言模型,RetNet的解码速度比具有键值缓存的Transformer快8.4×,节省了70%的内存。在训练过程中,RetNet还比Transformer节省了25-50%的内存和7×的加速,并且在highly-optimized FlashAttention
方面具有优势。此外,RetNet的推理延迟对批处理大小不敏感,允许巨大的吞吐量,认为是遥遥领先的;
2.2 Retentive Networks
在介绍框架之前,应该对Transformer和RoPE旋转位置编码有一定的了解,可以看下面两篇博客:
[transformer]论文实现:Attention Is All You Need_transformer vaswani论文-CSDN博客
Retentive Network 的大致结构和 Transformer 是一致的,其不同点主要在于利用了 multi-scale retention MSR
替换掉了 multi-heads attention MHA
;Retentive Network 中的Retention
有三种计算方法:The Parallel Representation of Retention, The Recurrent Representation of Retention, The Chunkwise Recurrent Representation of Retention;论文中主要介绍了前面两种,并行计算的结果和循环计算的结果是一致的;
在这里我们先给出计算过程,再去证明为什么并行计算的结果和循环计算的结果一致;
Retention
The Parallel Representation of Retention
论文中相关内容如下:
可以看到在计算 Q 和 K的时候,多出现了一个 Θ,这里的 通过这篇博客[RoFormer]论文实现:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING_roformer: enhanced transformer with rotaryposition-CSDN博客,可以发现其其本质就是RoPE中的一个旋转矩阵;
这里插入介绍一下旋转矩阵的快速计算技巧:
结合下面代码:
def rotate_every_two(x): x1 = x[:, :, :, ::2] x2 = x[:, :, :, 1::2] x = torch.stack((-x2, x1), dim=-1) return x.flatten(-2) def theta_shift(x, sin, cos): return (x * cos) + (rotate_every_two(x) * sin)
theta_shift
函数的返回值就是旋转矩阵,得到的旋转矩阵如图:
而 Θ ‾ \overline{\Theta} Θ 表示 Θ \Theta Θ共轭,其旋转矩阵相对于 Θ \Theta Θ,cos的结果不变,sin的结果变为相反数,从上图中可以明显的发现两个旋转矩阵是相互转置的;
从原文中看,上面的红框是很自然的推导出下面的红框的,其中十字架是共轭转置的意思,从这篇博客中得到的解释有出入关于RoPE旋转位置编码的理解-CSDN博客:
RoPE论文中原文可以用下面等式化等号:
可以看到 用矩阵进行了替换,所以说如果按照这样解释的话 在并入转置矩阵中的时候,其应该是自动发生了共轭关系,应该是
从源代码中也可以发现:
def parallel_forward(self, qr, kr, v, mask): bsz, tgt_len, embed_dim = v.size() vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len qk_mat = qk_mat * mask # invariant after normalization qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4) output = torch.matmul(qk_mat, vr) output = output.transpose(1, 2) return output def forward( self, x, rel_pos, chunkwise_recurrent=False, incremental_state=None ): ... qr = theta_shift(q, sin, cos) ## here kr = theta_shift(k, sin, cos) ## here output = self.parallel_forward(qr, kr, v, inner_mask)
其对 Q和 K 采取的处理方式是一样的,所以
从我对 的解释上来说,这里的 应该还是 Θ;这样全文就通顺了;
这里还有一个Dnm 其中在 n ≥ m中 γ应该是大于0小于1的,这样就可以根据距离削弱关系,而在 n < m nn<m中出现0表示一个自回归关系,也就是说只能注意到前面的内容,无法看到后面的内容;
类似于达到如下图一样的结果:
这就是Retention的并行表示内容;
The Recurrent Representation of Retention
有计算公式:
Q=(XWQ)⊙ΘK=(XWK)⊙Θ¯V=XWVSn=γSn−1+KnTVnRetention(Xn)=QnSn,n=1,…,|x|
\begin{align} &Q=(XW_Q)\odot \Theta \\ &K=(XW_K)\odot \overline{\Theta} \\ &V=XW_V \\ &S_n= \gamma S_{n-1} + K_n^TV_n \\ &Retention(X_n)=Q_nS_n, n=1,\dots,|x| \\ \end{align}证明结果一致
从循环推导并行:
这里要注意的是 A的维度是 d × d , Kn和 Qn的维度是 1 × d
接着这里进行了一步处理就是对角化,这里原文中的字母看着别扭,不用他们的字母:
如果没有重根,任何矩阵都是可以对角化的,这里做了一个小假设,就是 A 是可以对角化的;其中 P 和 Λ的维度都是 d × d 这样的话,原式子可以化简为
同时,由于
由于 WQ,WK都是学习参数,我们可以把 P并入其中,再化简就可以得到:
把 并入 中,由于 Λ 是对角矩阵得到:
这里再对 Λ进行处理,论文中令
其中 γ和 θ是一个 d维向量,但是我怎么也想不出来这个是怎么变成 d × d 维矩阵的,如果把 γ看做一个标量, 看作一个对角矩阵,是很容易的就可以推出最下面这个公式:
但是看到 我左看右看上看下看都觉得不舒服,感觉就是哪里出现了问题;从理论上是这样的,但是从代码实现上, K和 Q采取的是同样的处理方式;
这里我们把
其中
由于 R 的性质
我们可以得到:
直到我看了复数矩阵相关的内容,对复数矩阵进行转置是自动要进行共轭操作的,从这种角度进行解释就是说原文 是错误的,应该改为 ,但这只是我的猜测;如果是真错了,清华大学我真想@@他几下;
推导结束!
The Chunkwise Recurrent Representation of Retention
这里只贴一下,看起来很复杂但是用图一下就能说明;
Gated Multi-Scale Retention
Gated Multi-Scale Retention MSR
多尺度保留机制类似于多头注意力机制,模型的维度为 ,每个头的维度为 d,一共有 个头,每个头和多头注意力机制的每个头一样使用不同的
由于使用了不同的 γi,导致 Retention的尺度会发生变化,需要进行Normalization;
这里利用GroupNorm的尺度不变性来Normalization提高保留层的数值精度,由于
由于尺度不变性,这并不影响最终结果
Overall Architecture of Retention Networks
类似于Transformer的红框部分:
最主要区别是把MHA换成了MSR;
2.3 对比
三、过程实现
略;
四、整体总结
网络上说抄袭的rwkv模型…