Transformer、RNN和SSM的相似性探究:揭示看似不相关的LLM架构之间的联系

本文涉及的产品
实时计算 Flink 版,5000CU*H 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
简介: 通过探索大语言模型(LLM)架构之间的潜在联系,我们可能开辟新途径,促进不同模型间的知识交流并提高整体效率。尽管Transformer仍是主流,但Mamba等线性循环神经网络(RNN)和状态空间模型(SSM)展现出巨大潜力。近期研究揭示了Transformer、RNN、SSM和矩阵混合器之间的深层联系,为跨架构的思想迁移提供了可能。本文深入探讨了这些架构间的相似性和差异,包括Transformer与RNN的关系、状态空间模型在自注意力机制中的隐含作用以及Mamba在特定条件下的重写方式。

通过探索看似不相关的大语言模型(LLM)架构之间的潜在联系,我们可能为促进不同模型间的思想交流和提高整体效率开辟新的途径。

尽管Mamba等线性循环神经网络(RNN)和状态空间模型(SSM)近来备受关注,Transformer架构仍然是LLM的主要支柱。这种格局可能即将发生变化:像Jamba、Samba和Griffin这样的混合架构展现出了巨大的潜力。这些模型在时间和内存效率方面明显优于Transformer,同时在能力上与基于注意力的LLM相比并未显著下降。

近期研究揭示了不同架构选择之间的深层联系,包括Transformer、RNN、SSM和matrix mixers,这一发现具有重要意义,因为它为不同架构间的思想迁移提供了可能。本文将深入探讨Transformer、RNN和Mamba 2,通过详细的代数分析来理解以下几点:

  1. Transformer在某些情况下可以视为RNN(第2节)
  2. 状态空间模型可能隐藏在自注意力机制的掩码中(第4节)
  3. Mamba在特定条件下可以重写为掩码自注意力(第5节)

这些联系不仅有趣,还可能对未来的模型设计产生深远影响。

LLM中的掩码自注意力机制

首先,让我们回顾一下经典的LLM自注意力层的结构:

更详细的结构如下:

自注意力层的工作流程如下:

  1. 将查询矩阵Q和键矩阵K相乘,得到一个L×L的矩阵,包含查询和键的标量积。
  2. 对结果矩阵进行归一化。
  3. 将归一化后的矩阵与L×L的注意力掩码进行元素级乘法。图中展示了默认的因果掩码——左侧的0-1矩阵。这一步骤将较早查询与较晚键的乘积置零,防止注意力机制"看到未来"。
  4. 对结果应用softmax函数。
  5. 最后,将注意力权重矩阵A与值矩阵V相乘。输出的第t行可表示为:

这意味着第i个值是通过"第t个查询对第i个键的注意力权重"来加权的。

这种架构中的多个设计选择都可能被修改。接下来我们将探讨一些可能的变体。

线性化注意力

注意力公式中的Softmax函数确保了值是以和为1的正系数混合的。这种设计保持了某些统计特性,但同时也带来了限制。例如即使我们希望利用结合律,如(QK^T)V = Q(K^TV),也无法突破Softmax的限制。

为什么结合律如此重要?因为改变乘法顺序可能显著影响计算复杂度:

左侧公式需要计算一个L×L矩阵,如果这个矩阵完全显现在内存中,复杂度为O(L²d),内存消耗为O(L²)。右侧公式需要计算一个d×d矩阵,复杂度为O(Ld²),内存消耗为O(d²)。

随着上下文长度L的增加,左侧公式的计算成本rapidly become prohibitively非常的高。为了解决这个问题,我们可以考虑移除Softmax。详细展开带有Softmax的公式:

其中

是Softmax函数。指数函数是主要的障碍,它阻止了我们从中提取任何项。如果我们直接移除指数函数:

那么归一化因子

也随之消失。

这个简化后的公式存在一个问题:q_t^T k_s不能保证为正,这可能导致值以不同符号的系数混合,这在理论上是不合理的。更糟糕的是,分母可能为零,会导致计算崩溃。为了缓解这个问题,我们可以引入一个"良好的"元素级函数φ(称为核函数):

原始研究建议使用φ(x) = 1 + elu(x)作为核函数。

这种注意力机制的变体被称为线性化注意力。它的一个重要优势是允许我们利用结合律:

括号中M, K^T和V之间的关系现在变得相当复杂,不再仅仅是普通的矩阵乘法和元素级乘法。我们将在下一节详细讨论这个计算单元。

如果M是一个因果掩码,即对角线及以下为1,对角线以上为0:

那么计算可以进一步简化:

这可以通过一种简单的递归方式计算:

这是在2020年ICML上首次提出线性化注意力的论文"Transformers are RNNs"。在这个公式中,我们有两个隐藏状态:向量z_t和矩阵h_t(φ(k_t)^T v_t是列向量乘以行向量,得到一个d×d矩阵。

而近期的研究often以更简化的形式呈现线性化注意力,去除了φ函数和分母:

线性化注意力具有两个主要优势:

  1. 作为递归机制,它在推理时相对于序列长度L具有线性复杂度。
  2. 作为Transformer模型,它可以高效地并行训练。

但是你可能会问:如果线性化注意力如此优秀,为什么它没有在所有LLM中广泛应用?我们在讨论注意力的二次复杂度问题?实际上基于线性化注意力的LLM在训练过程中stability较低,且capability略逊于标准自注意力。这可能是因为固定的d×d形状的瓶颈比可调整的L×L形状的瓶颈能传递的信息更少。

进一步探索

RNN和线性化注意力之间的联系在近期的多项研究中得到了重新发现和深入探讨。一个common pattern是使用具有如下更新规则的矩阵隐藏状态:

其中k_t和v_t可以视为某种"键"和"值",RNN层的输出形式为:

这本质上等同于线性注意力。下面两篇论文提供了有趣的一些样例:

1、xLSTM (2024年5月): 该论文提出了对著名的LSTM递归架构的改进。其mLSTM块包含一个矩阵隐藏状态,更新方式如下:

输出通过将这个状态与一个"查询"相乘得到。(注意:该论文的线性代数设置与我们的相反,查询、键和值是列向量而非行向量,因此v_t k_t^T的顺序看起来可能有些奇怪。)

2、Learning to (learn at test time) (2024年7月): 这是另一种具有矩阵隐藏状态的RNN架构,它的隐藏状态W是一个函数的参数,在t的迭代过程中通过梯度下降优化:

这里的设置也是转置的,因此顺序看起来有些不同。尽管数学表达比Wt = W{t-1} + v_t k_t^T更复杂,但可以简化为这种形式。

以上两篇论文我们都详细介绍过,有兴趣的可以自行搜索

注意力掩码

在简化了掩码注意力机制后,我们可以开始探索其潜在的发展方向。一个明显的研究方向是选择不同的下三角矩阵(确保不会"看到未来")作为掩码M,而不是简单的0-1因果掩码。在进行这种探索之前,我们需要解决由此带来的效率问题。

在前一节中,我们使用了一个简单的0-1因果掩码M,这使得递归计算成为可能。但在一般情况下,这种递归技巧不再适用:

系数m_ts不再相同,也不存在将y_3与y_2关联的简单递归公式。因此,对于每个t我们都需要从头开始计算总和,这使得计算复杂度再次变为L的二次方而不是线性的。

解决这个问题的关键在于我们不能使用任意的掩码M,而应该选择特殊的、"良好"的掩码。我们需要那些可以快速与其他矩阵相乘(注意不是元素级乘法)的掩码。为了理解如何从这种特性中获益,让我们详细分析如何高效计算:

首先明确这个表达式的含义:

如果深入到单个索引级别:

为了便于后续讨论,可以用不同的颜色标记索引,而不是块:

现在我们可以提出一个四步算法:

步骤1. 利用K和V创建一个三维张量Z,其中:

(每个轴都标注了其长度。)这一步骤需要O(Ld²)的时间和内存复杂度。值得注意的是,如果我们在洋红色轴t上对这个张量求和,我们将得到矩阵乘积K^T V:

步骤2. 将M乘以这个张量(注意不是元素级乘法)。M乘以Z沿着洋红色轴t的每个"列"。

这正好得到:

将这个结果记为H。接下来只需要将所有内容乘以q,这将在接下来的两个步骤中完成。

步骤3a. 取Q并与H的每个j = const层进行元素级乘法:

这将得到:

这一步骤需要O(Ld²)的时间和内存复杂度。

步骤3b. 沿i轴对结果张量求和:

这一步骤同样需要O(Ld²)的时间和内存复杂度。最终得到了所需的结果:

在这个过程中,最关键的是第二步,我们故意省略了其复杂度分析。一个简单的估计是:

每次矩阵乘法需要O(L²)的复杂度,重复d²次

这将导致一个巨大的O(L²d²)复杂度。但是我们的目标是选择特殊的M,使得将M乘以一个向量的复杂度为O(RL),其中R是某个不太大的常数

例如如果M是0-1因果矩阵,那么与它相乘实际上就是计算累积和,这可以在O(L)时间内完成。但还存在许多其他具有快速向量乘法特性的结构化矩阵选项。

在下一节中将讨论这种矩阵类型的一个重要例子——半可分离矩阵,它与状态空间模型有着密切的联系。

半可分离矩阵与状态空间模型

让我们回顾一下(离散化的)状态空间模型(SSM)的定义。SSM是一类连接1维输入x_t、r维隐藏状态h_t和1维输出u_t的序列模型,其数学表达式如下:

在离散形式中,SSM本质上是一个带有跳跃连接的复杂线性RNN。为了简化后续讨论,我们甚至可以通过设置D_t = 0来忽略跳跃连接。

让我们将SSM表示为单个矩阵乘法:

其中

M是一个下三角矩阵,类似于我们之前讨论的注意力掩码。

这种类型的矩阵具有一个重要的优势:

一个L × L的下三角矩阵,如果其元素可以以这种方式表示,则可以使用O(rL)的内存存储,并且具有O(rL)的矩阵-向量乘法复杂度,而不是默认的O(L²)。

这意味着每个状态空间模型都对应一个结构化的注意力掩码M,可以在具有线性化注意力的高效Transformer模型中使用。

即使没有周围的查询-键-值机制,半可分离矩阵M本身已经相当复杂和富有表现力。它本身可能就是一个掩码注意力机制。我们将在下一节中详细探讨这一点。

状态空间对偶性

在这里,我们将介绍Mamba 2论文中的一个核心结果。

让我们再次考虑y = Mu,其中u = u(x)是输入的函数,M是一个可分离矩阵。如果我们考虑一个非常特殊的情况,其中每个A_t都是一个标量矩阵:A_t = a_t I。在这种情况下公式变得特别简单:

这里的

只是一个标量。还可以将C_i和B_i堆叠成矩阵B和C,使得:

现在我们还需要定义矩阵

然后就可以很容易地验证:

这个表达式是否看起来很熟悉?这实际上是一个掩码注意力机制,其中:

  • G作为掩码
  • C作为查询矩阵Q
  • B作为转置的键矩阵K^T
  • u作为值矩阵V

在经典的SSM中,B和C是常量。但在Mamba模型中,它们被设计为依赖于数据,这进一步强化了与注意力机制的对应关系。这种特定状态空间模型与掩码注意力之间的对应关系在Mamba 2论文中被称为状态空间对偶性

进一步探索

使用矩阵混合器而不是更复杂的架构并不是一个全新的idea。一个早期的例子是是MLP-Mixer,它在计算机视觉任务中使用MLP而不是卷积或注意力来进行空间混合。

尽管当前研究主要集中在大语言模型(LLM)上,但也有一些论文提出了用于编码器模型的非Transformer、矩阵混合架构。例如:

  1. 来自Google研究的FNet,其矩阵混合器M基于傅里叶变换。
  2. Hydra,除了其他创新外,还提出了半可分离矩阵在非因果(非三角)工作模式下的适应性方案。

总结

本文深入探讨了Transformer、循环神经网络(RNN)和状态空间模型(SSM)之间的潜在联系。文章首先回顾了传统的掩码自注意力机制,然后引入了线性化注意力的概念,解释了其计算效率优势。接着探讨了注意力掩码的优化,引入了半可分离矩阵的概念,并阐述了其与状态空间模型的关系。最后介绍了状态空间对偶性,揭示了特定状态空间模型与掩码注意力之间的对应关系。通过这些分析,展示了看似不同的模型架构之间存在深层联系,为未来模型设计和跨架构思想交流提供了新的视角和可能性。

https://avoid.overfit.cn/post/cc1b1bb7816b412790e9224484cd5b56

作者:Stanislav Fedotov

目录
相关文章
|
2月前
|
机器学习/深度学习 传感器 自然语言处理
基于Transformer架构的时间序列数据去噪技术研究
本文介绍了一种基于Transformer架构的时间序列去噪模型。通过生成合成数据训练,模型在不同噪声条件下展现出强去噪能力。文章详细解析了Transformer的输入嵌入、位置编码、自注意力机制及前馈网络等关键组件,并分析实验结果与注意力权重分布。研究为特定任务的模型优化和专业去噪模型开发奠定了基础。
200 14
基于Transformer架构的时间序列数据去噪技术研究
|
3月前
|
机器学习/深度学习 PyTorch 调度
MiTS与PoTS:面向连续值时间序列的极简Transformer架构
本文探讨了将标准Transformer架构应用于连续值时间序列数据的最小化调整方案,提出了极简时间序列Transformer(MiTS-Transformer)和位置编码扩展时间序列Transformer(PoTS-Transformer)。通过替换嵌入层为线性映射层,MiTS-Transformer实现了对正弦波序列的有效学习。而PoTS-Transformer则通过在高维空间中进行位置编码,结合低维模型架构,解决了长序列处理与过拟合问题。实验结果表明,这两种模型在不同类型的时间序列预测任务中表现出色,为基于Transformer的时间序列预测提供了高效基准方案。
86 5
MiTS与PoTS:面向连续值时间序列的极简Transformer架构
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
深入剖析Transformer架构中的多头注意力机制
多头注意力机制(Multi-Head Attention)是Transformer模型中的核心组件,通过并行运行多个独立的注意力机制,捕捉输入序列中不同子空间的语义关联。每个“头”独立处理Query、Key和Value矩阵,经过缩放点积注意力运算后,所有头的输出被拼接并通过线性层融合,最终生成更全面的表示。多头注意力不仅增强了模型对复杂依赖关系的理解,还在自然语言处理任务如机器翻译和阅读理解中表现出色。通过多头自注意力机制,模型在同一序列内部进行多角度的注意力计算,进一步提升了表达能力和泛化性能。
|
6月前
|
机器学习/深度学习 编解码 人工智能
超越Transformer,全面升级!MIT等华人团队发布通用时序TimeMixer++架构,8项任务全面领先
一支由麻省理工学院、香港科技大学(广州)、浙江大学和格里菲斯大学的华人研究团队,开发了名为TimeMixer++的时间序列分析模型。该模型在8项任务中超越现有技术,通过多尺度时间图像转换、双轴注意力机制和多尺度多分辨率混合等技术,实现了性能的显著提升。论文已发布于arXiv。
523 84
|
5月前
|
机器学习/深度学习 人工智能 NoSQL
记忆层增强的 Transformer 架构:通过可训练键值存储提升 LLM 性能的创新方法
Meta研究团队开发的记忆层技术通过替换Transformer中的前馈网络(FFN),显著提升了大语言模型的性能。记忆层使用可训练的固定键值对,规模达百万级别,仅计算最相似的前k个键值,优化了计算效率。实验显示,记忆层使模型在事实准确性上提升超100%,且在代码生成和通用知识领域表现优异,媲美4倍计算资源训练的传统模型。这一创新对下一代AI架构的发展具有重要意义。
229 11
记忆层增强的 Transformer 架构:通过可训练键值存储提升 LLM 性能的创新方法
|
5月前
|
机器学习/深度学习 人工智能 并行计算
Titans:谷歌新型神经记忆架构,突破 Transformer 长序列处理的瓶颈
Titans 是谷歌推出的新型神经网络架构,通过神经长期记忆模块突破 Transformer 在处理长序列数据时的瓶颈,支持并行计算,显著提升训练效率。
175 5
Titans:谷歌新型神经记忆架构,突破 Transformer 长序列处理的瓶颈
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
AI自己长出了类似大脑的脑叶?新研究揭示LLM特征的惊人几何结构
近年来,大型语言模型(LLM)的内部运作机制备受关注。麻省理工学院的研究人员在论文《The Geometry of Concepts: Sparse Autoencoder Feature Structure》中,利用稀疏自编码器(SAE)分析LLM的激活空间,揭示了其丰富的几何结构。研究发现,特征在原子、大脑和星系三个尺度上展现出不同的结构,包括晶体结构、中尺度模块化结构和大尺度点云结构。这些发现不仅有助于理解LLM的工作原理,还可能对模型优化和其他领域产生重要影响。
151 25
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
RNN回归!Bengio新作大道至简与Transformer一较高下
研究团队提出了一种名为“minimal LSTMs and GRUs”的新型RNN模型,通过简化传统LSTM和GRU结构,去除了隐藏状态对输入、遗忘和更新门的依赖,实现了无需BPTT的高效并行训练。该模型不仅保持了RNN处理序列数据的优势,还大幅提升了训练速度,在多个任务上的表现与Transformer相当,同时减少了参数量。研究结果发表于论文《minimal LSTMs and GRUs》。
118 9
|
7月前
|
人工智能 自然语言处理 测试技术
苹果一篇论文得罪大模型圈?Transformer不会推理,只是高级模式匹配器!所有LLM都判死刑
苹果公司发布论文《GSM-Symbolic: Understanding the Limitations of Mathematical Reasoning in Large Language Models》,质疑大型语言模型(LLM)在数学推理方面的能力。尽管LLM在GSM8K等测试中表现良好,但在新基准测试GSM-Symbolic中,其准确率随数值变化而显著下降,表明LLM可能依赖于记忆和模式匹配而非真正的数学理解。这一发现引发了AI领域的广泛讨论。
121 5
|
7月前
|
机器学习/深度学习 自然语言处理 计算机视觉
探索深度学习中的Transformer架构
探索深度学习中的Transformer架构
151 2