【RetNet】论文解读:Retentive Network: A Successor to Transformer for Large Language Models

简介: 【RetNet】论文解读:Retentive Network: A Successor to Transformer for Large Language Models

Retentive Network: A Successor to Transformer for Large Language Models

这是由微软和清华一起发布的论文,在这篇文章中,作者提出了一个代替transformer的基础架构:retentive network,该架构可以同时达到 training parallelism, low-cost inference, good performance. 并从理论上推导出了recurrence与attention之间的联系,这种retentive机制在序列模型中有三种计算范式:parallel, recurrent, and chunkwise recurrent. 其中并行表示允许训练并行性。循环表示可以实现低成本的O (1)推理,从而在不牺牲性能的情况下提高了解码需要的吞吐量、并降低了延迟和GPU内存。块级递归表示促进了具有线性复杂度的高效长序列建模,其中每个块都被并行编码,同时递归地总结块。

一、完整代码

https://github.com/microsoft/unilm/tree/master/retnet

二、论文解读

2.1 介绍

论文开头提出了一个不可能三角,分别是training parallelism, low-cost inference, good performance;以往的架构只能获得三种优势中的两种,而RetNet可以全部获得;

首先是Linear Transformer :其主要处理的方式是对k和v进行处理,例如[Linformer]论文实现:Linformer: Self-Attention with Linear Complexity_linformer网络结构-CSDN博客是通过证明self-attention是一个低秩矩阵来减少k和v的维度进而得到线性复杂度的效果,即low-cost inference,但是其降低了Transformer的效果;

第二个是Recurrent Network,随着不断的优化,其最大的缺点就是不能并行训练;

最后一个是Transformer,其最大的不足便是复杂度是 O ( n 2 ) O(n^2) O(n2),这导致序列长度的增加增加了GPU内存消耗和延迟,并降低了推理速度。

这里论文提出的RetNet,可以同时获得training parallelism, low-cost inference, good performance三种优秀的性质,其通过采用一种multi-scale retention 机制去替换multi-head attention

作者通过实验表明,RetNet在scaling curves序列长度和in-context learning上下文学习方面相较于Transformer是持续超过的状态,同时,RetNet的inference cost 是 O ( 1 ) O(1) O(1) ;对于7B参数量和8k序列长度的语言模型,RetNet的解码速度比具有键值缓存的Transformer快8.4×,节省了70%的内存。在训练过程中,RetNet还比Transformer节省了25-50%的内存和7×的加速,并且在highly-optimized FlashAttention方面具有优势。此外,RetNet的推理延迟对批处理大小不敏感,允许巨大的吞吐量,认为是遥遥领先的;

2.2 Retentive Networks

在介绍框架之前,应该对Transformer和RoPE旋转位置编码有一定的了解,可以看下面两篇博客:

[transformer]论文实现:Attention Is All You Need_transformer vaswani论文-CSDN博客

[RoFormer]论文实现:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING_roformer: enhanced transformer with rotaryposition-CSDN博客

Retentive Network 的大致结构和 Transformer 是一致的,其不同点主要在于利用了 multi-scale retention MSR 替换掉了 multi-heads attention MHA;Retentive Network 中的Retention有三种计算方法:The Parallel Representation of Retention, The Recurrent Representation of Retention, The Chunkwise Recurrent Representation of Retention;论文中主要介绍了前面两种,并行计算的结果和循环计算的结果是一致的;

在这里我们先给出计算过程,再去证明为什么并行计算的结果和循环计算的结果一致;

Retention
The Parallel Representation of Retention

论文中相关内容如下:

可以看到在计算 Q 和 K的时候,多出现了一个 Θ,这里的 image.png 通过这篇博客[RoFormer]论文实现:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING_roformer: enhanced transformer with rotaryposition-CSDN博客,可以发现其其本质就是RoPE中的一个旋转矩阵;

这里插入介绍一下旋转矩阵的快速计算技巧:

结合下面代码:

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

theta_shift函数的返回值就是旋转矩阵,得到的旋转矩阵如图:

而 Θ ‾ \overline{\Theta} Θ 表示 Θ \Theta Θ共轭,其旋转矩阵相对于 Θ \Theta Θ,cos的结果不变,sin的结果变为相反数,从上图中可以明显的发现两个旋转矩阵是相互转置的;

从原文中看,上面的红框是很自然的推导出下面的红框的,其中十字架是共轭转置的意思,从这篇博客中得到的解释有出入关于RoPE旋转位置编码的理解-CSDN博客

RoPE论文中原文可以用下面等式化等号:

image.png

可以看到 image.png 用矩阵进行了替换,所以说如果按照这样解释的话 image.png 在并入转置矩阵中的时候,其应该是自动发生了共轭关系,应该是 image.png

从源代码中也可以发现:

def parallel_forward(self, qr, kr, v, mask):
  bsz, tgt_len, embed_dim = v.size()
  vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
  qk_mat = qk_mat * mask
  # invariant after normalization
  qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4)
  output = torch.matmul(qk_mat, vr)
  output = output.transpose(1, 2)
  return output
    
def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
    ...
  qr = theta_shift(q, sin, cos) ## here
    kr = theta_shift(k, sin, cos) ## here
    output = self.parallel_forward(qr, kr, v, inner_mask)

其对 Q和 K 采取的处理方式是一样的,所以

从我对 image.png 的解释上来说,这里的 image.png 应该还是 Θ;这样全文就通顺了;

这里还有一个Dnm 其中在 n ≥ m中 γ应该是大于0小于1的,这样就可以根据距离削弱关系,而在 n < m nn<m中出现0表示一个自回归关系,也就是说只能注意到前面的内容,无法看到后面的内容;

类似于达到如下图一样的结果:

这就是Retention的并行表示内容;

The Recurrent Representation of Retention

有计算公式:

Q=(XWQ)⊙ΘK=(XWK)⊙Θ¯V=XWVSn=γSn−1+KnTVnRetention(Xn)=QnSn,n=1,…,|x|

\begin{align} &Q=(XW_Q)\odot \Theta \\ &K=(XW_K)\odot \overline{\Theta} \\ &V=XW_V \\ &S_n= \gamma S_{n-1} + K_n^TV_n \\ &Retention(X_n)=Q_nS_n, n=1,\dots,|x| \\ \end{align}

证明结果一致

从循环推导并行:

image.png

这里要注意的是  A的维度是 d × d , Kn和  Qn的维度是 1 × d

接着这里进行了一步处理就是对角化,这里原文中的字母看着别扭,不用他们的字母: image.png

如果没有重根,任何矩阵都是可以对角化的,这里做了一个小假设,就是 A 是可以对角化的;其中 P 和  Λ的维度都是 d × d 这样的话,原式子可以化简为

  image.png

同时,由于

  image.png

由于 WQ,WK都是学习参数,我们可以把  P并入其中,再化简就可以得到: image.png

image.png 并入 image.png 中,由于 Λ 是对角矩阵得到:

image.png

这里再对 Λ进行处理,论文中令

image.png

其中 γ和  θ是一个 d维向量,但是我怎么也想不出来这个是怎么变成 d × d 维矩阵的,如果把 γ看做一个标量, image.png 看作一个对角矩阵,是很容易的就可以推出最下面这个公式: image.png

但是看到 image.png 我左看右看上看下看都觉得不舒服,感觉就是哪里出现了问题;从理论上是这样的,但是从代码实现上, KQ采取的是同样的处理方式;

这里我们把 image.png

其中

由于 R 的性质

  image.png

我们可以得到:

image.png

直到我看了复数矩阵相关的内容,对复数矩阵进行转置是自动要进行共轭操作的,从这种角度进行解释就是说原文 image.png 是错误的,应该改为 image.png ,但这只是我的猜测;如果是真错了,清华大学我真想@@他几下;

推导结束!

The Chunkwise Recurrent Representation of Retention

这里只贴一下,看起来很复杂但是用图一下就能说明;

Gated Multi-Scale Retention

Gated Multi-Scale Retention MSR 多尺度保留机制类似于多头注意力机制,模型的维度为 image.png ,每个头的维度为  d,一共有 image.png 个头,每个头和多头注意力机制的每个头一样使用不同的 image.png

由于使用了不同的 γi,导致  Retention的尺度会发生变化,需要进行Normalization;

这里利用GroupNorm的尺度不变性来Normalization提高保留层的数值精度,由于 image.png

image.png

由于尺度不变性,这并不影响最终结果

Overall Architecture of Retention Networks

类似于Transformer的红框部分:

最主要区别是把MHA换成了MSR;

2.3 对比

三、过程实现

略;

四、整体总结

网络上说抄袭的rwkv模型…


目录
相关文章
|
机器学习/深度学习 算法 数据库
Dataset之LFW:LFW人脸数据库的简介、安装、使用方法之详细攻略
Dataset之LFW:LFW人脸数据库的简介、安装、使用方法之详细攻略
Dataset之LFW:LFW人脸数据库的简介、安装、使用方法之详细攻略
|
机器学习/深度学习 Python
【机器学习】包裹式特征选择之递归特征消除法
【机器学习】包裹式特征选择之递归特征消除法
2831 4
|
Python
Python 将PowerPoint (PPT/PPTX) 转为HTML
使用Python将PowerPoint转换为HTML以适应网络分享。需安装`Spire.Presentation for Python`库,通过`pip install Spire.Presentation`。示例包括:1) 全部转换,使用`Presentation.SaveToFile()`方法;2) 转换特定幻灯片,通过`Presentation.Slides[]`获取幻灯片再保存。代码示例展示了具体操作步骤。
1227 6
|
5月前
|
机器学习/深度学习 传感器 算法
Python | K折交叉验证的参数优化的弹性网络回归预测及可视化算法
本教程介绍基于Python的K折交叉验证与参数优化的弹性网络回归预测算法,涵盖贝叶斯、随机及网格搜索三种调参方法,结合SHAP分析、密度散点图与热力图等可视化技术,适用于多领域回归任务,代码及数据完整可复现。
307 0
|
云安全 人工智能 自然语言处理
|
机器学习/深度学习 数据可视化 算法
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
神经常微分方程(Neural ODEs)是深度学习领域的创新模型,将神经网络的离散变换扩展为连续时间动力系统。本文基于Torchdyn库介绍Neural ODE的实现与训练方法,涵盖数据集构建、模型构建、基于PyTorch Lightning的训练及实验结果可视化等内容。Torchdyn支持多种数值求解算法和高级特性,适用于生成模型、时间序列分析等领域。
757 77
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
|
12月前
|
机器学习/深度学习 运维 监控
实时异常检测实战:Flink+PAI 算法模型服务化架构设计
本文深入探讨了基于 Apache Flink 与阿里云 PAI 构建的实时异常检测系统。内容涵盖技术演进、架构设计、核心模块实现及金融、工业等多领域实战案例,解析流处理、模型服务化、状态管理等关键技术,并提供性能优化与高可用方案,助力企业打造高效智能的实时异常检测平台。
1118 1
|
机器学习/深度学习 算法 PyTorch
深度学习笔记(十三):IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU、WIOU损失函数分析及Pytorch实现
这篇文章详细介绍了多种用于目标检测任务中的边界框回归损失函数,包括IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU和WIOU,并提供了它们的Pytorch实现代码。
5472 1
深度学习笔记(十三):IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU、WIOU损失函数分析及Pytorch实现
|
机器学习/深度学习 测试技术 计算机视觉
注意力机制汇总,包括SE、CBAM、ECA等
注意力机制汇总,包括SE、CBAM、ECA等
2952 1