attention计算过程的一些细节

简介: attention计算的一些细节解释

最近,有粉丝问我,attention结构中计算qkv的时候,为什么要做kvcache呢?他看了一些文章,没看懂。

为什么要做kvcache?

假设模型的输入序列长度是2,隐藏层的维度是H,那么q、k、v的维度分别是[2, H]

假设它们的值分别是:

q=[q1,
   q2]
k=[k1,
   k2]
v=[v1,
   v2]

那么首先q*k的结果为:

[q1*k1, q1*k2
 q2*k1, q2*k2]

然后需要做一个mask,只留下下三角的值,其他值都取0,得到:

[q1*k1, 0
 q2*k1, q2*k2]

为什么要做mask,我认为是要和训练时的规则保持一致,因为训练的时候,是认为每个token只能看到它前面的词的。

然后计算qk*v:

[q1*k1*v1,
 q2*k1*v1+q2*k2*v2]

完成后续的计算可以预测得到1个新的token。如果还需要继续预测下一个词,在下一次计算的时候我们假设q、k、v为:

q=[q1,
   q2,
   q3]
k=[k1,
   k2,
   k3]
v=[v1,
   v2,
   v3]

同样得到q*k为:

[q1*k1, 0,     0
 q2*k1, q2*k2, 0,
 q3*k1, q3*k2, q3*k3]

qk*v为:

[q1*k1*v1,
 q2*k1*v1+q2*k2*v2,
 q3*k1*v1+q3*k2*v2+q3*k3*v3]

可以看到,第2次计算得到的qk相比于第1次的qk只是多了第3行。而第3行的值是q3*[k1, k2, k3],所以为了避免重复计算,我们只需要在第2次计算的时候,只计算新token对应的q3和k3,然后把k3和第1次计算得到的[k1, k2]拼接起来即可,[k1, k2]就是 k cache。

同样可以发现,第2次计算得到的qkv相比于第1次的qkv只是多了第3行。而第3行的值是qk*[v1, v2, v3],所以为了避免重复计算,我们只需要在第2次计算的时候,只计算新token对应的v3,然后把v3和第1次计算得到的[v1, v2]拼接起来即可,[v1, v2]就是 v cache。

以此类推,在后续的增量推理过程中,每次只需要计算新token的q、k、v,然后利用之前缓存的kv cache计算qk和qkv。

transformer是怎样预测出下一个词的?

首先,从数学层面来讲,是这样计算的:

首先,假设输入序列的长度是L,隐藏层的特征维度是H,词汇表的长度是V,那么在计算qkv的过程中,输入x的shape变化如下:

q*k:(L, H)x(H, L)->(L, L)

qk*v:(L, L)x(L, H)->(L, H)

然后再经过forward layer的一系列全连接层,得到的输出shape为(L, V),而它的最后一个分量,也就是output[L-1],就是预测结果的概率分布。

那么怎么理解这个计算过程呢?这个就可以有很多答案了,我一般是这么给别人解释的:首先在计算q*k的时候,qk的最后一个分量是用最后一个词去和其他词的key值做乘法,这一步相当于计算最后一个词和句子中每个词的相关性,然后乘以v就相当于把最后一个词和其他词的相关性进行一个组合,后面再通过多个全连接层进行上下文理解这个词在整个句子中的含义,并预测出下一个词。

这里又引入了另一个问题,既然在首次计算时,只用到了最后一个分量,为什么还要计算qk和qkv的第1到第L-1个分量的值呢?这是因为大模型由多个decoder layer叠加组成。第1个decoder输出的结果还需要作为x输入给第2个decoder layer,进行多轮"思考"。再具体一点,我们还是假设输入序列长度是2,经过第1个decoder layer后输出为:

[h1,
 h2]

那么它再作为输入传给第2个decoder layer,第2个decoder layer计算得到的qkv是:

[q1*k1*v1,
 q2*k1*v1+q2*k2*v2]

它的最后一个分量是q2k1v1+q2k2v2,其中的k1、v1都和h1相关,所以做首次计算(也就是我们常说的全量计算)时,qk和qkv的每个分量都要计算。

大家还有什么疑问呢?欢迎讨论哦!

目录
相关文章
|
11月前
|
存储 算法 调度
|
10月前
|
算法 开发者 Python
MindIE DeepSeek MTP特性定位策略
最近MindIE开始支持DeepSeek MTP(multi token prediction)特性了,用于推理加速。但是有些开发者打开MTP开关后,没有发现明显的性能提升。这篇文章提供一种定位策略。
240 1
|
11月前
|
人工智能 PyTorch 算法框架/工具
ACK AI Profiling:从黑箱到透明的问题剖析
本文从一个通用的客户问题出发,描述了一个问题如何从前置排查到使用AI Profiling进行详细的排查,最后到问题定位与解决、业务执行过程的分析,从而展现一个从黑箱到透明的精细化的剖析过程。
|
11月前
|
机器学习/深度学习 存储 人工智能
浅入浅出——生成式 AI
团队做 AI 助理,而我之前除了使用一些 AI 类产品,并没有大模型相关的积累。故先补齐一些基本概念,避免和团队同学沟通起来一头雾水。这篇文章是学习李宏毅老师《生成式 AI 导论》的学习笔记。
973 27
浅入浅出——生成式 AI
|
11月前
|
开发框架 人工智能 Java
破茧成蝶:阿里云应用服务器让传统 J2EE 应用无缝升级 AI 原生时代
本文详细介绍了阿里云应用服务器如何助力传统J2EE应用实现智能化升级。文章分为三部分:第一部分阐述了传统J2EE应用在智能化转型中的痛点,如协议鸿沟、资源冲突和观测失明;第二部分展示了阿里云应用服务器的解决方案,包括兼容传统EJB容器与微服务架构、支持大模型即插即用及全景可观测性;第三部分则通过具体步骤说明如何基于EDAS开启J2EE应用的智能化进程,确保十年代码无需重写,轻松实现智能化跃迁。
778 42
|
11月前
|
存储 SQL 大数据
从 o11y 2.0 说起,大数据 Pipeline 的「多快好省」之道
SLS 是阿里云可观测家族的核心产品之一,提供全托管的可观测数据服务。本文以 o11y 2.0 为引子,整理了可观测数据 Pipeline 的演进和一些思考。
553 34
|
11月前
|
Kubernetes 调度 开发者
qwen模型 MindIE PD分离部署问题定位
使用MindIE提供的PD分离特性部署qwen2-7B模型,使用k8s拉起容器,参考这个文档进行部署:https://www.hiascend.com/document/detail/zh/mindie/100/mindieservice/servicedev/mindie_service0060.html,1个Prefill,1个Decode。 最后一步测试推理请求的时候,出现报错:model instance has been finalized or not initialized。
667 1
|
11月前
|
消息中间件 运维 监控
加一个JVM参数,让系统可用率从95%提高到99.995%
本文针对一个高并发(10W+ QPS)、低延迟(毫秒级返回)的系统因内存索引切换导致的不稳定问题,深入分析并优化了JVM参数配置。通过定位问题根源为GC压力大,尝试了多种优化手段:调整MaxTenuringThreshold、InitialTenuringThreshold、AlwaysTenure等参数让索引尽早晋升到老年代;探索PretenureSizeThreshold和G1HeapRegionSize实现索引直接分配到老年代;加速索引复制过程以及升级至JDK11使用ZGC。
766 82
加一个JVM参数,让系统可用率从95%提高到99.995%