记忆层增强的 Transformer 架构:通过可训练键值存储提升 LLM 性能的创新方法

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,5000CU*H 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
简介: Meta研究团队开发的记忆层技术通过替换Transformer中的前馈网络(FFN),显著提升了大语言模型的性能。记忆层使用可训练的固定键值对,规模达百万级别,仅计算最相似的前k个键值,优化了计算效率。实验显示,记忆层使模型在事实准确性上提升超100%,且在代码生成和通用知识领域表现优异,媲美4倍计算资源训练的传统模型。这一创新对下一代AI架构的发展具有重要意义。

大语言模型(LLM)通过其参数储存了大量信息,这些信息主要以密集层中线性矩阵变换的权重形式存在。然而,参数规模的扩大必然导致计算成本和能源消耗的显著增加。

这种参数存储方式是否可以通过更高效的键值查找机制来优化?

尽管此前已有多项相关研究,但在当前 AI 架构规模下的实践尚属首次。

Meta 研究团队通过开发记忆层技术,成功实现了对现有大语言模型的性能提升。该技术通过替换一个或多个 Transformer 层中的前馈网络(FFN)来实现功能。

实验数据显示,记忆层的引入使大语言模型在事实准确性方面提升了 100% 以上。同时其在代码生成和通用知识领域的表现可与使用 4 倍计算资源训练的传统大语言模型相媲美。

在事实性任务评估中,搭载记忆层的大语言模型的性能明显优于在相似计算资源和参数规模条件下训练的专家混合型(Mixture-of-experts)架构。

本文将深入探讨记忆层的技术原理及其对大语言模型性能的提升机制,这一技术创新对下一代 AI 架构的发展具有重要意义。

记忆层的技术原理

我们先看一下Transformer的基本机构

记忆层在功能实现上与 Transformer 的注意力机制有相似之处。基本原理是:给定查询(

Q

)、键(

K

)和值(

V

),通过 softmax 函数计算查询与键之间的相似度,并据此对值(

V

)进行加权求和。

记忆层与传统注意力机制的主要区别在于两个方面:

首先,传统注意力机制中的键和值是针对每个查询动态计算的,而记忆层中的键和值是可训练的固定参数。这意味着这些参数通过训练获得并持久保存。

其次,记忆层所使用的键值对规模达到百万级别

系统仅使用与查询最相似的前 k 个键及其对应值来计算输出,这种方法显著提高了大规模运算时的计算效率。记忆层的数学表达可以通过以下方程系统来描述:

首先,基于查询与键之间的相似度计算确定前 k 个键的索引(

I

):

其中 q 表示查询向量,K 表示可训练的键矩阵。

随后,计算选定键的相似度得分(

K(I)q

),并通过 Softmax 函数进行归一化,得到权重向量(

s

):

其中 q 表示查询向量,K(I) 表示已选择的前 k 个键矩阵。

最后,利用前 k 个值的加权和计算输出向量(

y

):

其中 s 表示经过 softmax 归一化的权重向量,V(I) 表示选定的前 k 个值矩阵。在记忆层中,每个词元嵌入都进行独立的处理,这一点与传统 Transformer 中的前馈层处理方式相同。

大规模相似键搜索的优化策略

在大规模场景下查找最相似键值的计算开销较大。传统的最近邻搜索算法流程如下:

  • 计算查询向量与所有键之间的相似度(如余弦相似度),时间复杂度为 O(N ⋅ n),其中 N 为键的数量,n 为向量维度
  • 对相似度进行排序,时间复杂度为 O(N log(N))
  • 选取相似度最高的前 k 个键
  • 利用这 k 个键计算最终输出

该方法的空间复杂度为

O(N ⋅ n)

,在处理百万级键值对时计算资源消耗过大。近似最近邻(ANN)搜索同样不适用于此场景,因为 ANN 需要预先构建静态索引。由于记忆层中的键是可训练参数且在训练过程中持续更新,这就要求不断重建索引。

那么,是否存在更优的解决方案?

研究团队采用了一种源自先前研究的可训练乘积量化键技术,下面将详细说明其实现原理。

键矩阵分解策略

该方法不直接使用完整的键矩阵(

K

),而是将其分解为两个较小的矩阵(

K(1)

K(2)

)。

原始键矩阵的维度为

N × n

,分解后的两个子矩阵维度均为

√N × n/2

,其中

N

表示键的总数,

n

表示向量维度。

完整的键矩阵可以通过这两个子矩阵的笛卡尔积表示:

_K = K(1) X K(2)_

这种设计避免了显式构建完整矩阵,从而实现了计算资源的优化。

查询向量分解

与键矩阵分解相对应,查询向量(

Q

)也被分解为两个子向量(

Q(1)

Q(2)

)。原始查询向量的维度为

n

,分解后的子向量维度各为

n/2

。这两个子向量分别与对应的键子矩阵进行运算。

相似键的检索与相似度计算

对于

Q(1)

,系统在

K(1)

中检索前 k 个相似键,得到索引集合(

I(1)

)。随后通过 Softmax 函数计算相似度得分(

s(1)

)。

Q(2)

K(2)

之间进行相同的操作。

全局最优解的获取

通过对索引和得分应用 Argmax 函数,可以得到全局最优的前 k 个索引和对应得分:

这种方法的优势在于:

将查询与所有

N

个键的直接比较转化为与两个较小集合的比较,使得时间和空间复杂度从

O(N ⋅ n)

降低到

O(√N ⋅ n)

,大幅提升了计算效率。

GPU 并行计算的实现

记忆层包含数百万个可训练参数(键和值矩阵)。为了高效处理这些参数,系统采用了以下并行计算策略:

  1. 将参数沿嵌入维度分片
  2. 在多个 GPU 上分布式存储
  3. 每个 GPU 负责管理其分配到的参数分片
  4. 通过进程组协调各 GPU 之间的运算

查询操作的执行流程如下:

  1. 识别并分发相关索引至各 GPU
  2. 各 GPU 在其负责的分片中检索对应嵌入
  3. 收集并整合各 GPU 的部分结果,得到最终输出

GPU 并行化的记忆层运算示意图

GPU 计算效率优化

PyTorch 提供的

EmbeddingBag

函数可用于计算记忆层中前 k 个嵌入的加权和。然而,其默认实现在 GPU 内存带宽利用率方面存在局限。

测试显示,默认实现的内存带宽利用率不足 400 GB/s,远未充分发挥现代 GPU 的性能潜力。为此研究团队开发了专门的 CUDA 内核,用于优化前向和反向传播的计算效率。

优化后的实现达到了 3 TB/s 的内存带宽,接近 NVIDIA H100 GPU 3.35 TB/s 的理论峰值,使得嵌入运算的端到端性能提升了约 6 倍。

此外通过引入基于 SiLU 非线性函数的输入依赖门控机制,进一步提升了记忆层的训练性能。

优化后的输出计算公式如下:

其中各参数定义如下:

  • silu(x) = x ∗ σ(x),σ(x) 为 sigmoid 函数
  • 表示 Hadamard 积(逐元素乘法)
  • x 为记忆层输入
  • y 为经门控机制调制后的输出
  • W(1)W(2) 为可训练权重矩阵

上图为标准记忆层与引入输入依赖门控机制后的记忆层性能对比

在实践中发现,当小规模基础模型与大规模记忆层结合时,可能出现训练不稳定的问题。为解决这一问题,引入了 QK 归一化技术。该技术在计算点积之前对查询向量(

Q

)和键向量(

K

)进行归一化处理。

记忆层的最优配置策略

在深度神经网络中,浅层网络主要学习基础特征,而深层网络则负责提取复杂模式。实验表明,在多个层次中引入记忆层可以获得最佳效果。为了控制参数规模,在所有层间采用了共享内存池机制。这种设计使得多个层可以共享访问同一内存资源,提高了架构效率。

实验数据显示,在不超过 3 个层中使用记忆层可以持续提升模型性能,但过度替换密集前馈网络(FFN)层会导致性能下降。

这一现象表明,稀疏记忆层与密集前馈层各有其独特优势,最佳方案是将两者结合使用。

记忆层增强型大语言模型的性能评估

研究团队选择 Llama 系列模型(Llama2 和 Llama3)作为基准,将其一个或多个前馈层(FFN)替换为共享记忆层进行实验。

实验设置包括两种配置:基础记忆模型(使用单一记忆层)和增强型记忆模型("Memory +",使用三个记忆层并集成 Swilu 非线性函数)。

SwiLU 非线性函数:其中 β 为可学习参数,σ(x) 为 sigmoid 函数

为进行对比分析,研究还包含了配置相当的专家混合模型(MoE,采用专家选择路由训练)和 PEER 模型。

实验结果分析

在问答(QA)任务评估中,记忆模型展现出显著优势:其性能超过了参数规模相当的密集模型,达到了参数数量两倍的密集模型的水平。

记忆增强型架构与基准模型在问答任务上的性能对比

增强型记忆模型("Memory +")的表现更为突出,其性能可与使用 2-4 倍计算资源训练的密集模型相匹敌。

各架构在问答任务中的准确率比较("Memory +"模型配置:100万记忆嵌入)

值得注意的是,PEER 模型在相同参数规模下的表现与基础记忆模型相当,但未能达到增强型记忆模型的水平。

同时,专家混合模型的性能显著低于记忆增强型模型。在固定基础模型参数的情况下扩展记忆参数规模时,模型在事实性问答任务上表现出显著的性能提升。

实验显示,配置 6400 万个键的 1.3B 参数记忆模型,仅使用 1/10 的计算量和一半的训练数据量,即可达到 Llama2 7B 模型的性能水平。

图表展示了 1.3B 参数模型在 NaturalQuestions(NQ)和 TriviaQA(TQA)基准测试中的性能指标:随着记忆规模扩大,事实性问答准确率提升,负对数似然(NLL)降低。虚线表示使用 10 倍计算资源、在 2 万亿词元上训练的 7B 模型的性能水平。

在 8B 规模模型的评估中,记忆模型在科学知识、通用知识和编程能力等基准测试上的表现明显优于传统密集模型。

特别值得一提的是,经过 1 万亿词元的训练,增强型记忆模型("Memory +")的性能已接近在 15 万亿词元(15 倍数据量)上训练的 Llama3.1 8B 模型。

总结

实验结果表明,记忆层技术在提升大语言模型性能方面具有显著优势。随着大语言模型逐渐接近计算资源和物理极限,这项技术的应用价值将愈发凸显。

论文:https://avoid.overfit.cn/post/bc94fb7278ff425f8af5ffa053a5ab12

作者:Dr. Ashish Bamania

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
目录
相关文章
|
11天前
|
机器学习/深度学习 人工智能 计算机视觉
MILS:无需对LLM进行额外训练就能处理多模态任务,Meta AI提出零样本生成多模态描述方法
MILS 是 Meta AI 推出的零样本生成高质量多模态描述方法,支持图像、视频和音频的描述生成,无需额外训练。
102 34
MILS:无需对LLM进行额外训练就能处理多模态任务,Meta AI提出零样本生成多模态描述方法
|
30天前
|
机器学习/深度学习 自然语言处理 PyTorch
深入剖析Transformer架构中的多头注意力机制
多头注意力机制(Multi-Head Attention)是Transformer模型中的核心组件,通过并行运行多个独立的注意力机制,捕捉输入序列中不同子空间的语义关联。每个“头”独立处理Query、Key和Value矩阵,经过缩放点积注意力运算后,所有头的输出被拼接并通过线性层融合,最终生成更全面的表示。多头注意力不仅增强了模型对复杂依赖关系的理解,还在自然语言处理任务如机器翻译和阅读理解中表现出色。通过多头自注意力机制,模型在同一序列内部进行多角度的注意力计算,进一步提升了表达能力和泛化性能。
|
23天前
|
自然语言处理 算法 JavaScript
面向长文本的多模型协作摘要架构:多LLM文本摘要方法
多LLM摘要框架通过生成和评估两个步骤处理长文档,支持集中式和分散式两种策略。每个LLM独立生成文本摘要,集中式方法由单一LLM评估并选择最佳摘要,而分散式方法则由多个LLM共同评估,达成共识。论文提出两阶段流程:先分块摘要,再汇总生成最终摘要。实验结果显示,多LLM框架显著优于单LLM基准,性能提升最高达3倍,且仅需少量LLM和一轮生成评估即可获得显著效果。
58 10
面向长文本的多模型协作摘要架构:多LLM文本摘要方法
|
27天前
|
机器学习/深度学习 人工智能 并行计算
Titans:谷歌新型神经记忆架构,突破 Transformer 长序列处理的瓶颈
Titans 是谷歌推出的新型神经网络架构,通过神经长期记忆模块突破 Transformer 在处理长序列数据时的瓶颈,支持并行计算,显著提升训练效率。
93 5
Titans:谷歌新型神经记忆架构,突破 Transformer 长序列处理的瓶颈
|
2月前
|
弹性计算 API 持续交付
后端服务架构的微服务化转型
本文旨在探讨后端服务从单体架构向微服务架构转型的过程,分析微服务架构的优势和面临的挑战。文章首先介绍单体架构的局限性,然后详细阐述微服务架构的核心概念及其在现代软件开发中的应用。通过对比两种架构,指出微服务化转型的必要性和实施策略。最后,讨论了微服务架构实施过程中可能遇到的问题及解决方案。
|
3月前
|
Cloud Native Devops 云计算
云计算的未来:云原生架构与微服务的革命####
【10月更文挑战第21天】 随着企业数字化转型的加速,云原生技术正迅速成为IT行业的新宠。本文深入探讨了云原生架构的核心理念、关键技术如容器化和微服务的优势,以及如何通过这些技术实现高效、灵活且可扩展的现代应用开发。我们将揭示云原生如何重塑软件开发流程,提升业务敏捷性,并探索其对企业IT架构的深远影响。 ####
83 3
|
3月前
|
Cloud Native 安全 数据安全/隐私保护
云原生架构下的微服务治理与挑战####
随着云计算技术的飞速发展,云原生架构以其高效、灵活、可扩展的特性成为现代企业IT架构的首选。本文聚焦于云原生环境下的微服务治理问题,探讨其在促进业务敏捷性的同时所面临的挑战及应对策略。通过分析微服务拆分、服务间通信、故障隔离与恢复等关键环节,本文旨在为读者提供一个关于如何在云原生环境中有效实施微服务治理的全面视角,助力企业在数字化转型的道路上稳健前行。 ####
|
2月前
|
Java 开发者 微服务
从单体到微服务:如何借助 Spring Cloud 实现架构转型
**Spring Cloud** 是一套基于 Spring 框架的**微服务架构解决方案**,它提供了一系列的工具和组件,帮助开发者快速构建分布式系统,尤其是微服务架构。
299 69
从单体到微服务:如何借助 Spring Cloud 实现架构转型
|
9天前
|
传感器 监控 安全
智慧工地云平台的技术架构解析:微服务+Spring Cloud如何支撑海量数据?
慧工地解决方案依托AI、物联网和BIM技术,实现对施工现场的全方位、立体化管理。通过规范施工、减少安全隐患、节省人力、降低运营成本,提升工地管理的安全性、效率和精益度。该方案适用于大型建筑、基础设施、房地产开发等场景,具备微服务架构、大数据与AI分析、物联网设备联网、多端协同等创新点,推动建筑行业向数字化、智能化转型。未来将融合5G、区块链等技术,助力智慧城市建设。
|
3月前
|
Dubbo Java 应用服务中间件
服务架构的演进:从单体到微服务的探索之旅
随着企业业务的不断拓展和复杂度的提升,对软件系统架构的要求也日益严苛。传统的架构模式在应对现代业务场景时逐渐暴露出诸多局限性,于是服务架构开启了持续演变之路。从单体架构的简易便捷,到分布式架构的模块化解耦,再到微服务架构的精细化管理,企业对技术的选择变得至关重要,尤其是 Spring Cloud 和 Dubbo 等微服务技术的对比和应用,直接影响着项目的成败。 本篇文章会从服务架构的演进开始分析,探索从单体项目到微服务项目的演变过程。然后也会对目前常见的微服务技术进行对比,找到目前市面上所常用的技术给大家进行讲解。
98 1
服务架构的演进:从单体到微服务的探索之旅