Differential Transformer: 通过差分注意力机制提升大语言模型性能

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
检索分析服务 Elasticsearch 版,2核4GB开发者规格 1个月
简介: 《Differential Transformer》论文提出了一种新的差分注意力机制,旨在解决传统Transformer模型过分关注不相关信息的问题。该机制通过计算两个独立的注意力图谱之差来消除注意力噪声,提高模型性能。实验结果显示,DIFF Transformer在减少参数量和训练token数量的同时,显著提升了多目标检索任务的准确率。

Transformer模型已经成为大语言模型(LLMs)的标准架构,但研究表明这些模型在准确检索关键信息方面仍面临挑战。今天介绍一篇名叫Differential Transformer的论文,论文的作者观察到一个关键问题:传统Transformer模型倾向于过分关注不相关的上下文信息,这种"注意力噪声"会影响模型的性能。

在这篇论文中,作者注意到transformer模型倾向于关注不相关的上下文。为了放大相关上下文的注意力分数,他们提出了一个新的注意力模型,称为差分注意力模型。在这个模型中,他们将查询和键值向量分成两组,并计算两个子注意力分数。

差分注意力机制

差分注意力机制(Differential Attention)的核心思想是通过计算两个独立的注意力图谱之差来消除注意力噪声。这种设计借鉴了电气工程中差分放大器的原理,通过对比两个信号的差异来消除共模噪声。

让我们看看论文中的第一个方程:

方程(1)

方程(1)显示,我们首先像标准注意力计算一样计算Q、K和V张量。关键点是我们将Q和K张量分成Q1、Q2和K1、K2子张量。

论文中输入X、Q1、Q2、K1、K2和V张量的形状

根据论文,Q和K张量的形状应该是Nx2d,因为Q1、Q2、K1和K2将是Nxd。输入X的形状是Nxd_model,这是论文中的嵌入维度。这就是为什么W_Q、W_K和W_V的可学习参数的形状必须是d_modelx2d。

论文中用于lambda计算的方程(2)

方程(2)展示了如何计算可学习参数lambda。在这个方程中有一个初始lambda参数。lambda是一个标量参数,但lambda_q1、lambda_k1、lambda_q2和lambda_k2是向量。这一点很关键。向量lambda_q和lambda_k的运算是点积。

用于lambda初始化的方程(3)

实验结果与性能提升

论文的实验表明,相比传统Transformer:

DIFF Transformer只需要约65%的模型参数量即可达到相同的性能,在训练token数量方面也只需要约65%就能达到相同效果

在Needle-In-A-Haystack测试中:4K上下文长度:DIFF Transformer在多目标检索任务中保持85%准确率;64K上下文长度:在深度为25%的位置检测时,比传统Transformer提升了76%的准确率

Python实现

下面我们根据论文的公式来做一个简单的实现,首先方程(3)展示了我们如何计算lambda_initial变量。现在让我们把方程转换成Python代码:

 deflambda_init_fn(depth):  
     return0.8-0.6*math.exp(-0.3*depth)

然后再写一个简单的Python函数,使用方程(3)。

 classDifferentialAttention(nn.Module):  
     def__init__(self, dim_model: int, head_nums: int, depth: int):  
         super().__init__()  

         self.head_dim=dim_model//head_nums  

         self.Q=nn.Linear(dim_model, 2*self.head_dim, bias=False)  
         self.K=nn.Linear(dim_model, 2*self.head_dim, bias=False)  
         self.V=nn.Linear(dim_model, 2*self.head_dim, bias=False)  
         self.scale=self.head_dim**-0.5  
         self.depth=depth  
         self.lambda_q1=nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))  
         self.lambda_q2=nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))  
         self.lambda_k1=nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))  
         self.lambda_k2=nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))  
         self.rotary_emb=RotaryEmbedding(self.head_dim*2)

在DifferentialAttention类中,我们实现了一个多头差分注意力机制。有dim_model(嵌入维度)、head_nums和depth参数。为Q1、Q2、K1和K2声明了四个lambda可学习参数,并使用均值为0、标准差为0.1的随机正态分布初始化它们。

     defforward(self, x):  
         lambda_init=lambda_init_fn(self.depth)  
         Q=self.Q(x)  
         K=self.K(x)  

         seq_len=x.shape[1]  
         cos, sin=self.rotary_emb(seq_len, device=x.device)  
         Q, K=apply_rotary_pos_emb(Q, K, cos, sin)  

         Q1, Q2=Q.chunk(2, dim=-1)  
         K1, K2=K.chunk(2, dim=-1)  
         V=self.V(x)  
         A1=Q1@K1.transpose(-2, -1) *self.scale  
         A2=Q2@K2.transpose(-2, -1) *self.scale  
         lambda_1=torch.exp(torch.sum(self.lambda_q1*self.lambda_k1, dim=-1).float()).type_as(Q1)  
         lambda_2=torch.exp(torch.sum(self.lambda_q2*self.lambda_k2, dim=-1).float()).type_as(Q2)  
         lambda_=lambda_1-lambda_2+lambda_init  
         return (F.softmax(A1, dim=-1)  -lambda_*F.softmax(A2, dim=-1)) @V

forward方法很直观。我分别实现了方程(1)和方程(2)。forward方法直接实现了论文中的伪代码。

多头差分注意力架构和伪代码

 classMultiHeadDifferentialAttention(nn.Module):  
     def__init__(self, dim_model: int, head_nums: int, depth: int):  
         super().__init__()  
         self.heads=nn.ModuleList([DifferentialAttention(dim_model, head_nums, depth) for_inrange(head_nums)])  
         self.group_norm=RMSNorm(dim_model)  
         self.output=nn.Linear(2*dim_model, dim_model, bias=False)  
         self.lambda_init=lambda_init_fn(depth)  

     defforward(self, x):  
         o=torch.cat([self.group_norm(h(x)) forhinself.heads], dim=-1)  
         o=o* (1-self.lambda_init)  
         returnself.output(o)

MultiHeadDifferentialAttention类是根据论文中的伪代码编写的。这里使用了RMSNorm而不是GroupNorm。

论文中使用多头差分注意力机制的语言模型的方程

最后使用实现的MultiHeadDifferentialAttention机制构建一个transformer解码器。

 classDifferentialTransformer(nn.Module):  
     def__init__(self, dim: int, depth: int, heads: int=8, head_dim: int=64, vocab_size: int=10000):  
         super().__init__()  
         self.vocab_size=vocab_size  
         self.layers=nn.ModuleList([  
             MultiHeadDifferentialAttention(dim, heads, depth_idx)  
             fordepth_idxinrange(depth)  
         ])  
         self.ln1=RMSNorm(dim)  
         self.ln2=RMSNorm(dim)  
         self.ffn=FeedForward(dim, (dim//3) *8)  
         self.output=nn.Linear(dim, self.vocab_size)  

     defforward(self, x):  
         forattninself.layers:  
             y=attn(self.ln1(x)) +x  
             x=self.ffn(self.ln2(y)) +y  
         returnself.output(x)

性能优化

论文还提供了两种FlashAttention实现方式:

1、支持不同维度的实现:

 def FlashDiffAttn_1(X, W_q, W_k, W_v, λ):
     Q1, Q2 = split(X @ W_q)
     K1, K2 = split(X @ W_k)
     V = X @ W_v
     A1 = flash_attn(Q1, K1, V)
     A2 = flash_attn(Q2, K2, V)
     return A1 - λ A2

固定维度的实现:

 def FlashDiffAttn_2(X, W_q, W_k, W_v, λ):
     Q1, Q2 = split(X @ W_q)
     K1, K2 = split(X @ W_k)
     V1, V2 = split(X @ W_v)
     A11 = flash_attn(Q1, K1, V1)
     A12 = flash_attn(Q1, K1, V2)
     A1 = Concat(A11, A12)
     A21 = flash_attn(Q2, K2, V1)
     A22 = flash_attn(Q2, K2, V2)
     A2 = Concat(A21, A22)
     return A1 - λ A2

Differential Transformer论文提出的两种FlashAttention实现方案各有特色。第一种实现(FlashDiffAttn_1)采用直接计算策略,允许Q、K、V具有不同的维度,这种灵活性使其更适合需要动态调整维度的场景,但可能在某些硬件上的优化效果不如第二种方案。第二种实现(FlashDiffAttn_2)通过将计算分解为多个相同维度的子运算,虽然计算步骤增多,但每个步骤都能充分利用硬件优化,特别是在支持张量核心的现代GPU上表现更好。

这两种实现的选择主要取决于具体应用场景:如果模型架构需要频繁调整维度或者需要更灵活的注意力机制,建议使用第一种实现;如果追求极致的计算效率且维度相对固定,第二种实现可能是更好的选择。从工程实践角度看,第二种实现与现有的FlashAttention优化库的兼容性更好,更容易在现有基础设施上部署和优化。

局限性和未来研究方向

Differential Transformer虽然在多个方面展现出优秀的性能,但仍然存在一些值得关注的局限性。首要的挑战来自其计算效率方面。由于模型需要同时计算两个独立的注意力图谱,这不可避免地增加了计算开销。在实际测试中,相比传统Transformer,DIFF Transformer在3B规模模型上的计算吞吐量降低了约9%,这种性能损失虽然可以通过更少的参数量来部分抵消,但在大规模部署场景中仍然需要认真考虑。

内存使用是另一个重要的局限性。模型需要存储两组独立的查询和键值向量,这导致了更高的内存占用。尽管这种设计对于提升模型性能是必要的,但在资源受限的环境下可能会造成部署困难。特别是在处理超长序列时,内存压力会进一步加大。

训练稳定性也是一个需要特别关注的问题。模型中λ参数的初始化策略对训练过程的稳定性有显著影响。研究发现,不同的λinit取值会导致训练收敛速度和最终性能的差异。虽然论文提出了一个基于层深度的初始化策略,但这种方案并非在所有场景下都能取得最优效果,有时需要根据具体任务进行调整。

基于这些局限性,论文提出未来的研究可以沿着几个重要方向展开。首先在计算效率优化方面,可以探索更高效的注意力核心实现。这包括研究如何更好地利用现代硬件特性,例如开发专门的CUDA核心来加速差分注意力的计算。同时考虑到模型产生的稀疏注意力模式,可以设计特定的稀疏计算优化策略,这不仅能提升计算效率,还能减少内存占用。

λ参数的动态调整机制是另一个值得深入研究的方向。当前的参数计算方案虽然有效,但仍有优化空间。可以考虑设计更灵活的自适应机制,使λ参数能够根据输入内容和任务特点动态调整,从而在不同场景下都能获得最佳性能。这可能需要引入额外的上下文感知机制,或者设计新的参数更新策略。

在内存优化方面,量化技术提供了一个有前景的研究方向。考虑到DIFF Transformer在处理激活值异常方面的优势,可以探索专门的量化策略。比如,研究如何在保持模型性能的同时,对注意力权重和中间状态进行更激进的量化,从而减少内存占用。这对于模型在边缘设备上的部署具有重要意义。

长文本建模能力的进一步提升也是一个重要研究方向。虽然当前模型在64K长度的实验中表现出色,但随着应用需求的增长,可能需要处理更长的序列。这要求研究如何在更长序列上保持模型的效率和性能,可能需要开发新的注意力机制变体或优化策略。

总结

DIFF Transformer通过创新的差分注意力机制成功提升了模型性能,特别是在长文本理解、关键信息检索和模型鲁棒性等方面。虽然存在一些计算效率和内存使用的权衡,但考虑到显著的性能提升和更少的参数需求,这是一个非常有价值的改进。这项工作为大语言模型的架构设计提供了新的思路,也为后续研究指明了几个重要的优化方向。

论文地址:

https://avoid.overfit.cn/post/f2e9e7856db24002beb7fc7d2dc33c96

目录
相关文章
|
4天前
|
存储 运维 安全
云上金融量化策略回测方案与最佳实践
2024年11月29日,阿里云在上海举办金融量化策略回测Workshop,汇聚多位行业专家,围绕量化投资的最佳实践、数据隐私安全、量化策略回测方案等议题进行深入探讨。活动特别设计了动手实践环节,帮助参会者亲身体验阿里云产品功能,涵盖EHPC量化回测和Argo Workflows量化回测两大主题,旨在提升量化投研效率与安全性。
云上金融量化策略回测方案与最佳实践
|
6天前
|
人工智能 自然语言处理 前端开发
从0开始打造一款APP:前端+搭建本机服务,定制暖冬卫衣先到先得
通义灵码携手科技博主@玺哥超carry 打造全网第一个完整的、面向普通人的自然语言编程教程。完全使用 AI,再配合简单易懂的方法,只要你会打字,就能真正做出一个完整的应用。
6025 18
|
18天前
|
人工智能 自动驾驶 大数据
预告 | 阿里云邀您参加2024中国生成式AI大会上海站,马上报名
大会以“智能跃进 创造无限”为主题,设置主会场峰会、分会场研讨会及展览区,聚焦大模型、AI Infra等热点议题。阿里云智算集群产品解决方案负责人丛培岩将出席并发表《高性能智算集群设计思考与实践》主题演讲。观众报名现已开放。
|
10天前
|
自然语言处理 数据可视化 API
Qwen系列模型+GraphRAG/LightRAG/Kotaemon从0开始构建中医方剂大模型知识图谱问答
本文详细记录了作者在短时间内尝试构建中医药知识图谱的过程,涵盖了GraphRAG、LightRAG和Kotaemon三种图RAG架构的对比与应用。通过实际操作,作者不仅展示了如何利用这些工具构建知识图谱,还指出了每种工具的优势和局限性。尽管初步构建的知识图谱在数据处理、实体识别和关系抽取等方面存在不足,但为后续的优化和改进提供了宝贵的经验和方向。此外,文章强调了知识图谱构建不仅仅是技术问题,还需要深入整合领域知识和满足用户需求,体现了跨学科合作的重要性。
|
6天前
|
人工智能 容器
三句话开发一个刮刮乐小游戏!暖ta一整个冬天!
本文介绍了如何利用千问开发一款情侣刮刮乐小游戏,通过三步简单指令实现从单个功能到整体框架,再到多端优化的过程,旨在为生活增添乐趣,促进情感交流。在线体验地址已提供,鼓励读者动手尝试,探索编程与AI结合的无限可能。
|
1月前
|
存储 人工智能 弹性计算
阿里云弹性计算_加速计算专场精华概览 | 2024云栖大会回顾
2024年9月19-21日,2024云栖大会在杭州云栖小镇举行,阿里云智能集团资深技术专家、异构计算产品技术负责人王超等多位产品、技术专家,共同带来了题为《AI Infra的前沿技术与应用实践》的专场session。本次专场重点介绍了阿里云AI Infra 产品架构与技术能力,及用户如何使用阿里云灵骏产品进行AI大模型开发、训练和应用。围绕当下大模型训练和推理的技术难点,专家们分享了如何在阿里云上实现稳定、高效、经济的大模型训练,并通过多个客户案例展示了云上大模型训练的显著优势。
|
10天前
|
Cloud Native Apache 流计算
PPT合集|Flink Forward Asia 2024 上海站
Apache Flink 年度技术盛会聚焦“回顾过去,展望未来”,涵盖流式湖仓、流批一体、Data+AI 等八大核心议题,近百家厂商参与,深入探讨前沿技术发展。小松鼠为大家整理了 FFA 2024 演讲 PPT ,可在线阅读和下载。
3543 10
PPT合集|Flink Forward Asia 2024 上海站
|
3天前
|
弹性计算 运维 监控
阿里云云服务诊断工具:合作伙伴架构师的深度洞察与优化建议
作为阿里云的合作伙伴架构师,我深入体验了其云服务诊断工具,该工具通过实时监控与历史趋势分析,自动化检查并提供详细的诊断报告,极大提升了运维效率和系统稳定性,特别在处理ECS实例资源不可用等问题时表现突出。此外,它支持预防性维护,帮助识别潜在问题,减少业务中断。尽管如此,仍建议增强诊断效能、扩大云产品覆盖范围、提供自定义诊断选项、加强教育与培训资源、集成第三方工具,以进一步提升用户体验。
612 242
|
23天前
|
人工智能 自然语言处理 前端开发
100个降噪蓝牙耳机免费领,用通义灵码从 0 开始打造一个完整APP
打开手机,录制下你完成的代码效果,发布到你的社交媒体,前 100 个@玺哥超Carry、@通义灵码的粉丝,可以免费获得一个降噪蓝牙耳机。
5954 16
|
5天前
|
消息中间件 人工智能 运维
12月更文特别场——寻找用云高手,分享云&AI实践
我们寻找你,用云高手,欢迎分享你的真知灼见!
506 37