Attention优化重大突破!显存减半效率倍增

简介: 本文探讨了Transformer中Attention机制的演变与优化。从2017年Transformer提出以来,各种改进如MQA、GQA、MLA等层出不穷,旨在降低计算复杂度和显存消耗,同时保持模型性能。文章首先介绍了Attention的基本原理,通过QKV矩阵运算实现序列建模。接着分析了优化方法:kv caching将计算复杂度从O(n^3)降至O(n^2),但带来显存压力;MQA、GQA等通过减少或压缩K/V降低显存需求;而NSV、MoBA等稀疏化研究进一步缓解长序列下的计算与存储负担,推动大模型向更长上下文扩展。

一,前言

从Transformer 17年被提出来, Attention 层的相关的工作从未间断,参考邱锡鹏教授21年关于Transformer的综述里的整理如下:
image.png
更耳熟能详一些的如:MQA,GQA ,DeepSeek的MLA,阶跃星辰的MFA,以及最近DeepSeek提出的NSA和Kimi提出的MoBA。
到底这么多的Attention为什么被提出,它们都意图解决什么问题。 我们用接下来的两个章节来解释一下。

二,Attention 是什么

为了更容易理解我们接下来的内容,我们先来看一下Attention的计算公式:
image.png
这个公式到底在做什么,flashattention的论文里说的比较清楚:
image.png
结合下图更有利于我们深入的理解attention(11. Self Attention - by Tom Yeh - AI by Hand ✍️)
image.png
下面会对以上的这张图做分块的解释:

语言模型简单来说就是针对输入的序列计算概率,推测下一个token是什么,新生成的token以及历史上所有的token又做为下一轮生成的输入。 上面这个图是模型多层transformer中的一层的计算过程的具象化展示,主要含有Attention和FFN(MLP)这2个子层。下面几个小节主要解释attention的计算部分。

1, 我们假设该图是transformer第一层,每个token是一个词,这样可以简化讨论。假设有4个词做为输入,被转化为4个向量,也就是图里的Features:x1,x2,x3,x4。 通过和模型的权重矩阵Wq,Wk 相乘,得到一个Q矩阵,也就是图里的q1,q2,q3,q4 这4个销向量。得到一个K矩阵也就是k1,k2,k3,k4 这4个销向量;
image.png
2, Q,K 通过矩阵乘法计算得到一个矩阵S
image.png
3,然后通过softmax 等计算得到矩阵P
image.png
4,输入x1,x2,x3,x4通过和模型的权重矩阵Wv相乘得到v1,v2,v3,v4 也就是矩阵V。 最后矩阵P和V相乘得到最终的输出向量z1,z2,z3,z4,也就是矩阵O
image.png

三,为什么能做和要做前言里提到的优化工作

从第二章节不难看出(Decoder-only的attention计算和上面略有不同,不影响计算复杂度讨论),如果不做任何优化,生成每一个token的计算复杂度是O(n^2),最终生成的序列全局计算复杂度是O(n^3)。对于上下文这个计算复杂度肯定是无法接受的。
1,所以,直觉上提升Attention的性能的做法是降低它的计算复杂度.
kv caching就是为了解决这个问题将单个token的计算复杂度降低到O(n) (n为当前序列长度),全局的复杂度就下降到O(n^2)极大的提升了性能。比如下图开不开起kv caching 作者观察到的。 kv caching 是一个空间换时间的做法,那么我们需要付出的就是更多的显存空间用来做k,v 的缓存。
image.png
kv caching 为什么可行参考下面2个图(https://x.com/_avichawla/status/1890288542322221206):
image.png
大体来说就是Decoder 里的attention计算由于当前token只参考之前的token而不向后参考,所以如第二章节的历史上的z向量可以在新token生成时不更新,Zk = (Qk K) V。

2,kv caching 极大的提升了性能,但是需要大量的cache 空间。比如上图所示看Llama3 70B,4k token需要10.5 GB 的cache空间,这个是极大的开销。为了降低cache的压力,开头提到的MQA,GQA等被提出,通过降低k v 的数量来降低对显存的消耗。同时模型能力没有出现明显的下降,所以得到了广泛的使用。DeepSeek 的MLA 则是通过压缩k v 大小的方式,进一步降低了cache 的压力,这里的风险是信息的损失会导致模型性能的下降,但从实际的结果来看,这个风险目前被处理得很好。
image.png
3,即使做了1,2的优化,问题就不存在了吗?当然不是,对于更长的上下文cache的压力还是很大,计算压力还是O(n^2) (n 为序列长度),这样模型实际是很难继续朝着长上下文scaling的。2和其他的一些研究表attention是稀疏的,序列越长attention越稀疏,像DeepSeek 的NSV,Kimi的MoBA以及别的一些研究就诞生了,通过这些研究我们可以进一步降低计算量和cache大小同时不降低模型性能。
image.png

大模型私有化部署,点这里https://c.aiiz.cn/Uny3qz

参考资料:

https://arxiv.org/pdf/1706.03762

https://arxiv.org/pdf/2106.04554

https://arxiv.org/pdf/2205.14135

https://arxiv.org/pdf/2502.11089

  1. Self Attention - by Tom Yeh - AI by Hand ✍️

The Annotated Transformer

https://x.com/_avichawla/status/1890288542322221206

相关文章
|
机器学习/深度学习 人工智能 负载均衡
基于 NVIDIA Megatron-Core 的 MoE LLM 实现和训练优化
本文将分享阿里云人工智能平台 PAI 团队与 NVIDIA Megatron-Core 团队在 MoE (Mixture of Experts) 大型语言模型(LLM)实现与训练优化上的创新工作。
|
6月前
|
负载均衡 NoSQL Redis
不增加 GPU,首 Token 延迟下降50%|LLM 服务负载均衡的新实践
针对LLM服务的特点,Higress AI网关以插件形式提供了面向LLM服务的负载均衡算法,包括全局最小请求数负载均衡、前缀匹配负载均衡以及GPU感知负载均衡,能够在不增加硬件成本的前提下,提升系统的吞吐能力、降低响应延迟,并实现更公平、高效的任务调度。
680 135
|
9月前
|
机器学习/深度学习 PyTorch 编译器
深入解析torch.compile:提升PyTorch模型性能、高效解决常见问题
PyTorch 2.0推出的`torch.compile`功能为深度学习模型带来了显著的性能优化能力。本文从实用角度出发,详细介绍了`torch.compile`的核心技巧与应用场景,涵盖模型复杂度评估、可编译组件分析、系统化调试策略及性能优化高级技巧等内容。通过解决图断裂、重编译频繁等问题,并结合分布式训练和NCCL通信优化,开发者可以有效提升日常开发效率与模型性能。文章为PyTorch用户提供了全面的指导,助力充分挖掘`torch.compile`的潜力。
1040 17
|
7月前
|
机器学习/深度学习 数据采集 人工智能
微调之后还能做什么?大模型后训练全链路技术解析
本文探讨了后训练的重要性、方法以及最新进展。文章将包含理论分析与实际操作指南,适合希望深入了解并应用这些技术的开发者。
1743 18
微调之后还能做什么?大模型后训练全链路技术解析
|
5月前
|
边缘计算 缓存 人工智能
EdgeShard:通过协作边缘计算实现高效的大语言模型推理——论文解读
EdgeShard是一种基于协作边缘计算的大语言模型(LLM)推理框架,旨在解决LLM在云端部署面临的延迟高、带宽压力大和隐私泄露等问题。通过将LLM分片部署在多个边缘设备上,结合云边协同与设备间协作,EdgeShard实现了高效的模型推理。其核心创新包括:联合设备选择与模型划分优化、支持流水线并行与微批处理、提出EdgeShard-No-Bubbles策略以减少设备空闲时间,从而显著提升推理吞吐量并降低延迟。实验表明,EdgeShard在异构边缘设备上可实现高达50%的延迟降低和2倍的吞吐量提升,支持全精度模型推理而无精度损失,为资源受限的边缘环境提供了高效的LLM部署方案。
1042 2
|
9月前
|
存储 机器学习/深度学习 缓存
vLLM 核心技术 PagedAttention 原理详解
本文系统梳理了 vLLM 核心技术 PagedAttention 的设计理念与实现机制。文章从 KV Cache 在推理中的关键作用与内存管理挑战切入,介绍了 vLLM 在请求调度、分布式执行及 GPU kernel 优化等方面的核心改进。PagedAttention 通过分页机制与动态映射,有效提升了显存利用率,使 vLLM 在保持低延迟的同时显著提升了吞吐能力。
4833 20
vLLM 核心技术 PagedAttention 原理详解
|
11月前
|
机器学习/深度学习 数据处理
大语言模型中的归一化技术:LayerNorm与RMSNorm的深入研究
本文分析了大规模Transformer架构(如LLama)中归一化技术的关键作用,重点探讨了LayerNorm被RMSNorm替代的原因。归一化通过调整数据量纲保持分布形态不变,提升计算稳定性和收敛速度。LayerNorm通过均值和方差归一化确保数值稳定,适用于序列模型;而RMSNorm仅使用均方根归一化,省略均值计算,降低计算成本并缓解梯度消失问题。RMSNorm在深层网络中表现出更高的训练稳定性和效率,为复杂模型性能提升做出重要贡献。
2547 14
大语言模型中的归一化技术:LayerNorm与RMSNorm的深入研究