LLM 加速技巧:Muti Query Attention

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
检索分析服务 Elasticsearch 版,2核4GB开发者规格 1个月
实时计算 Flink 版,5000CU*H 3个月
简介: MQA 是 19 年提出的一种新的 Attention 机制,其能够在保证模型效果的同时加快 decoder 生成 token 的速度。在大语言模型时代被广泛使用,很多LLM都采用了MQA,如Falcon、PaLM、StarCoder等。

在介绍MQA 之前,我们先回顾一下传统的多头注意力

Multi-Head Attention(MHA)

多头注意力是transformer 模型的默认注意力机制,如下图所示:

在文本生成方面,基于transformer 的自回归语言模型存在一个问题。在训练过程中可以获得真实的目标序列,并且可以有效地实现并行化。

但是在推理过程中,每个位置的查询都要处理在该位置或之前生成的所有键值对。也就是说自注意力层在特定位置的输出影响下一个令牌的生成,所以无法并行化,这使得推理变得非常的慢。

下图是基于transformer 解码器的自回归语言模型中自注意层的解码过程:

 defMHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
     q=tf.einsum("bd, hdk−>bhk", x, P_q)
     new_K=tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis=2)], axis=2)
     new_V=tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis=2)], axis=2)
     logits=tf.einsum("bhk, bhmk−>bhm", q, new_K)
     weights=tf.softmax(logits)
     O=tf.einsum("bhm, bhmv−>bhv", weights, new_V)
     Y=tf.einsum("bhv, hdv−>bd", O, P_o)
     returnY, new_K, new_V

其中:

X:当前的输入张量,m为当前步,m+1为阶跃,形状为[b, d]

P_q, P_k:查询和键投影张量,形状为[h, d, k]

P_v:值投影张量,形状为[h, d, v]

P_o:学习到的线性投影,形状为[h, d, v]

Prev_K:上一步的关键张量,形状为[b, h, m, k]

Prev_V:前一步的Value张量,形状为[b, h, m, v]

new_K:加上当前步的键张量,形状为[b, h, m+1, k]

new_V:加了当前步长的Value张量,形状为[b, h, m+1, v]

维度表示如下:

M:先前执行的步骤数

B:批量大小

D:输入和输出的尺寸

H:注意力头数

k:Q,K张量的另一个维度

v: v张量的另一个维度

Multi-Query Attention(MQA)

MQA是多头注意的一种变体。

MQA的方法是保持Q的初始头数,但K和V只有一个头,这意味着所有Q个头共享相同的K和V,因此称为Multi-Query,如下图所示:

从论文的解释中可以看到,MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。

MQA解码过程的代码本质上与MHA的代码相同,只是从中删除了表示头部尺寸的字母“h”。K, V, P_k, P_v的和方程:

 defMQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
     q=tf.einsum("bd, hdk−>bhk", x, P_q)
     new_K=tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis=2)], axis=2)
     new_V=tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis=2)], axis=2)
     logits=tf.einsum("bhk, bmk−>bhm", q, new_K)
     weights=tf.softmax(logits)
     O=tf.einsum("bhm, bmv−>bhv", weights, new_V)
     Y=tf.einsum("bhv, hdv−>bd", O, P_o)
     returnY, new_K, new_V

上面都是tf的代码,如果阅读有问题,我从 llm-foundry项目中找到了pytorch的代码实现,这里只做个摘抄,有兴趣的请看原项目

 classMultiheadAttention(nn.Module):

     def__init__(
             self,
             d_model: int,
             n_heads: int,
             device: str
         ):
         """
         Multi Head init func.

         Args:
             d_model (int): hidden state size, e.g. 768
             n_heads (int): 设定的注意力头数, e.g. 8
             device (str): _description_
         """
         super().__init__()

         self.d_model=d_model
         self.n_heads=n_heads

         self.Wqkv=nn.Linear(                       # Multi-Head Attention 的创建方法
             self.d_model, 
             3*self.d_model,                        # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
             device=device
         )                                            # (d_model, 3 * d_model)
         self.attn_fn=scaled_multihead_dot_product_attention
         self.out_proj=nn.Linear(
             self.d_model, 
             self.d_model, 
             device=device
         )

     defforward(
         self,
         x
     ):
         """
         forward func.

         Args:
             x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)

         Returns:
             _type_: _description_
         """
         qkv=self.Wqkv(x)                            # (1, 768, 3 * 768)

         query, key, value=qkv.chunk(                # 每个 tensor 都是 (1, 512, 768)
             3, 
             dim=2
         )     

         context, attn_weights, past_key_value=self.attn_fn(
             query,
             key,
             value,
             self.n_heads
         )                                             # (1, 512, 768)

         returnself.out_proj(context), attn_weights, past_key_value


 classMultiQueryAttention(nn.Module):
     """Multi-Query self attention.

     Using torch or triton attention implemetation enables user to also use
     additive bias.
     """

     def__init__(
         self,
         d_model: int,
         n_heads: int,
         device: Optional[str] =None,
     ):
         super().__init__()

         self.d_model=d_model
         self.n_heads=n_heads
         self.head_dim=d_model//n_heads

         self.Wqkv=nn.Linear(                           # Multi-Query Attention 的创建方法
             d_model,
             d_model+2*self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_model
             device=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量
         )

         self.attn_fn=scaled_multihead_dot_product_attention
         self.out_proj=nn.Linear(
             self.d_model, 
             self.d_model, 
             device=device
         )
         self.out_proj._is_residual=True  # type: ignore

     defforward(
         self,
         x,
     ):
         qkv=self.Wqkv(x)                                           # (1, 512, 960)

         query, key, value=qkv.split(                               # query -> (1, 512, 768)
             [self.d_model, self.head_dim, self.head_dim],            # key   -> (1, 512, 96)
             dim=2                                                    # value -> (1, 512, 96)
         )

         context, attn_weights, past_key_value=self.attn_fn(
             query,
             key,
             value,
             self.n_heads,
             multiquery=True,
         )

         returnself.out_proj(context), attn_weights, past_key_value

从代码中可以看到所有 头之间共享一份 key 和 value 的参数,但是如何将这 1 份参数同时让 8 个头都使用呢?

代码里使用矩阵乘法 matmul 来广播,使得每个头都乘以这同一个 tensor,以此来实现参数共享,主要是这个函数:scaled_multihead_dot_product_attention

 defscaled_multihead_dot_product_attention(
         query,
         key,
         value,
         n_heads,
         past_key_value=None,
         softmax_scale=None,
         attn_bias=None,
         key_padding_mask=None,
         is_causal=False,
         dropout_p=0.0,
         training=False,
         needs_weights=False,
         multiquery=False,
     ):
     q=rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
     kv_n_heads=1ifmultiqueryelsen_heads
     k=rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery 
                                                                     # (1, 512, 96) -> (1, 1, 96, 512)  if multiquery
     v=rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery 
                                                                     # (1, 512, 96) -> (1, 1, 512, 96)  if multiquery

     attn_weight=q.matmul(k) *softmax_scale                       # (1, 8, 512, 512)
     attn_weight=torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)

     out=attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
     out=rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)

     returnout, attn_weight, past_key_value

MQA指标测试

MQA能在多大程度上提高速度?让我们看看原文中提供的结果图表:

从上表可以看出,MQA在编码器上的速度提升不是很显著,但在解码器上的速度提升是相当显著的。

论文中也有关于质量的实验,结果表明MQA的性能与基线相比只是稍微低一些。降低应该是肯定的因为毕竟共享了参数,但是只要再可接受范围内并且能够大量提升速度这个降低就是可以接受的,对吧。

为什么MQA可以实现推理加速?

在MQA中,键张量和值张量的大小分别为b k和b v,而在MHA中,键张量和值张量的大小分别为b h k和b h v,其中h表示头的个数。

MQA通过以下方法实现推理加速:

1、KV缓存大小减少了h(头数量),这意味着需要存储在GPU内存中的张量也减少了。节省的空间可以用来增加批大小,从而提高效率。

2、减少了从内存中读取的数据量,从而减少了计算单元的等待时间,提高了计算利用率。

3、MQA有一个相对较小的KV数量,可以放入缓存(SRAM)中。MHA则需要较大的KV数量,不能完全存储在缓存中,需要从GPU内存(DRAM)读取,这很耗时。

总结

MQA是在2019年提出的,当时的应用还没有那么广泛。这是因为以前的模型不需要关心这些方面,例如,LSTM只需要维护一个状态,而不需要保留任何缓存。

当transformer最初被提出时,它主要用于Seq2Seq任务,特别是在Encoder-Decoder模型中。由于模型的规模不是很大,也并且没有太多的实际需求,所以MQA并没有引起太多的关注。

直到近年来(尤其是2023年开始)基于transformer的大型语言模型(如GPT)得到广泛应用后,推理的瓶颈才被人们重视。所以MQA才被发现非常有用,这主要是由于对大规模gpt式生成模型的实际需求。

最后我们再回顾以下这个论文:

https://avoid.overfit.cn/post/877de0f5a56d478d8133d75a05064e7e

作者:Florian June

目录
相关文章
|
机器学习/深度学习 人工智能 自然语言处理
LLM系列 | 11: 基于ChatGPT构建智能客服系统(query分类&安全检查&防注入)
本文主要介绍如何使用ChatGPT对智能客服领域中的客户咨询进行分类。此外还补充构建真实应用中如何对用户咨询内容和模型生成内容进行安全检查及其如何预防用户注入。
|
2月前
|
前端开发 机器人 API
前端大模型入门(一):用 js+langchain 构建基于 LLM 的应用
本文介绍了大语言模型(LLM)的HTTP API流式调用机制及其在前端的实现方法。通过流式调用,服务器可以逐步发送生成的文本内容,前端则实时处理并展示这些数据块,从而提升用户体验和实时性。文章详细讲解了如何使用`fetch`发起流式请求、处理响应流数据、逐步更新界面、处理中断和错误,以及优化用户交互。流式调用特别适用于聊天机器人、搜索建议等应用场景,能够显著减少用户的等待时间,增强交互性。
651 2
|
2月前
|
机器学习/深度学习 人工智能 运维
企业内训|LLM大模型在服务器和IT网络运维中的应用-某日企IT运维部门
本课程是为某在华日资企业集团的IT运维部门专门定制开发的企业培训课程,本课程旨在深入探讨大型语言模型(LLM)在服务器及IT网络运维中的应用,结合当前技术趋势与行业需求,帮助学员掌握LLM如何为运维工作赋能。通过系统的理论讲解与实践操作,学员将了解LLM的基本知识、模型架构及其在实际运维场景中的应用,如日志分析、故障诊断、网络安全与性能优化等。
100 2
|
2月前
|
机器学习/深度学习 数据采集 人工智能
文档智能 & RAG 让AI大模型更懂业务 —— 阿里云LLM知识库解决方案评测
随着数字化转型的深入,企业对文档管理和知识提取的需求日益增长。阿里云推出的文档智能 & RAG(Retrieval-Augmented Generation)解决方案,通过高效的内容清洗、向量化处理、精准的问答召回和灵活的Prompt设计,帮助企业构建强大的LLM知识库,显著提升企业级文档管理的效率和准确性。
|
27天前
|
机器学习/深度学习 人工智能 自然语言处理
深挖大模型幻觉!哈佛大学最新报告:LLM等价于众包,只是在输出网络共识
大型语言模型(LLM)如ChatGPT正改变人机交互,但在生成看似真实的错误信息方面存在“幻觉”问题。这种现象源于LLM依赖统计概率而非语义理解,导致在处理争议或冷门话题时易出错。研究显示,LLM的准确性高度依赖于训练数据的质量和数量。尽管如此,LLM仍具巨大潜力,需持续优化并保持批判性使用。
47 12
|
30天前
|
人工智能 自然语言处理
大模型在装傻!谷歌苹果最新发现:LLM知道但不告诉你,掌握知识比表现出来的多
在AI领域,大模型(LLM)展现出了惊人的进步,但在谷歌和苹果的最新研究中,发现这些模型有时会故意“装傻”,即使已知正确答案也不告知用户。这种“隐藏智慧”现象揭示了大模型可能具备超出表面表现的深层能力,对AI评估与应用提出了新挑战,同时也带来了设计更高效模型的新机遇。论文链接:https://arxiv.org/pdf/2410.02707
42 11
|
1月前
|
自然语言处理 开发者
多模态大模型LLM、MLLM性能评估方法
针对多模态大模型(LLM)和多语言大模型(MLLM)的性能评估,本文介绍了多种关键方法和标准,包括模态融合率(MIR)、多模态大语言模型综合评估基准(MME)、CheckList评估方法、多模态增益(MG)和多模态泄露(ML),以及LLaVA Bench。这些方法为评估模型的多模态和多语言能力提供了全面的框架,有助于研究者和开发者优化和改进模型。
133 5
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
大模型强崩溃!Meta新作:合成数据有剧毒,1%即成LLM杀手
在人工智能领域,大型语言模型(LLMs)的快速发展令人瞩目,但递归生成数据可能导致“模型崩溃”。Meta的研究揭示,模型在训练过程中会逐渐遗忘低概率事件,导致数据分布偏差。即使少量合成数据(如1%)也会显著影响模型性能,最终导致崩溃。研究强调保留原始数据的重要性,并提出社区合作和技术手段来区分合成数据和真实数据。论文地址:https://www.nature.com/articles/s41586-024-07566-y
82 2
|
1月前
|
人工智能 自然语言处理 算法
政务培训|LLM大模型在政府/公共卫生系统的应用
本课程是TsingtaoAI公司面向某卫生统计部门的政府职员设计的大模型技术应用课程,旨在系统讲解大语言模型(LLM)的前沿应用及其在政府业务中的实践落地。课程涵盖从LLM基础知识到智能化办公、数据处理、报告生成、智能问答系统构建等多个模块,全面解析大模型在卫生统计数据分析、报告撰写和决策支持等环节中的赋能价值。
74 2
|
2月前
|
人工智能 自然语言处理 运维
前端大模型应用笔记(一):两个指令反过来说大模型就理解不了啦?或许该让第三者插足啦 -通过引入中间LLM预处理用户输入以提高多任务处理能力
本文探讨了在多任务处理场景下,自然语言指令解析的困境及解决方案。通过增加一个LLM解析层,将复杂的指令拆解为多个明确的步骤,明确操作类型与对象识别,处理任务依赖关系,并将自然语言转化为具体的工具命令,从而提高指令解析的准确性和执行效率。