1、分层归一化的“层级-Flash”(精确派)
痛点:跨 tile 的 softmax 需要全局一致的归一化;Flash 用在线 log-sum-exp 解决,但层级只有一层。
思路:做成层级前缀和:
Level-1:SM 内 tile 归一化(寄存器/共享内存)
Level-2:Block 级汇总(共享内存/片上 SRAM)
Level-3:Grid 级全局拼接(一次全局归并+重标定)
收益:对超长上下文(>256k tokens)时,全局重标定的通信次数从 O(#tiles) 近似压到 O(#levels)。
代价/风险:实现复杂,需严格证明数值等价;需要良好的 block 排布策略。
2、自适应动态分块(精确派)
痛点:固定 tile 大小在序列分布极不均衡时不是最优(有的区段“信息密”,有的“稀”)。
思路:运行时用低开销统计(如每 tile 的最大点积/方差)动态调整 tile 尺寸与扫描顺序,并在归一化时带上对应缩放。
收益:IO 更接近下界;高密度子段更小 tile、更多并行;稀疏子段更大 tile、减少调度开销。
代价/风险:需要一个轻量“探针轮”或边算边估计的控制逻辑。
3、 “先筛后精”的可证明上/下界筛除(近似→可控误差)
痛点:大量 QK^T 的内积贡献极小,照算不值当。
思路:对每个 Q-tile 维护上界(如 ||q||·||K_tile||)和已积累下界;若上界低于“仍可能改变前 Top-k 权重”的阈值,直接跳过该 K-tile。
收益:在长上下文/主题块明显时,大幅减少无效 tile 乘加;误差由边界控制。
代价/风险:需要严谨的阈值与“误差-召回”曲线设计;对极度均匀分布收益有限。
4、 分段前缀-softmax 的严格等价实现(精确派)
痛点:跨 tile 归一化仍然需要合并多次中间状态。
思路:把 log-sum-exp 状态 (m,d)(m, d)(m,d)(最大值与指数和)做成可并联的半群:
(m, d) ⊕ (m′, d′) = (max(m, m′), d · e^(m - m) + d′ · e^(m′ - m)), m* = max(m, m′)
支持任意顺序/拓扑的归并(像 prefix-scan)。
收益:自由调度 tile 的同时保持数学等价;便于多卡/多 SM 并行。
代价/风险:工程上要保证溢出与舍入误差的下界(建议 BF16/FP32 累加)。
5、 KV-Cache 的在线压缩/重构(精确派+系统优化)
痛点:推理阶段 KV-cache 逐 token 涨;IO 成新瓶颈。
思路:对“冷”KV-tile 使用乘积量化/低秩重构的可逆存储(热 tile 原精度,冷 tile 压缩),在访问到时快速解码到共享内存再参与计算;频度-温度策略动态迁移冷热。
收益:显存与带宽压力显著下降,几乎不动模型结构。
代价/风险:需要确保解码延迟 < 省下的 IO;量化误差要对 softmax 稳定性友好(建议 value 侧更高精度)。
6、 异构精度的算子级布局(精确派)
痛点:一刀切精度不是最优。
思路:
QKTQK^TQKT 点积用 FP8/INT8 输入 + FP32 累加
log-sum-exp 的状态 (m,d)(m,d)(m,d) 强制 FP32
pVpVpV 的 V 参与乘法用 BF16,累加用 FP32
收益:显著降低带宽与存储,几乎不损精度
代价/风险:需要张量核路径稳定+校准(per-tile 缩放更稳)
7、 2.5D 张量并行的 Flash 排程(精确派+多卡)
痛点:数据并行/张量并行对 Attention 的通信开销大。
思路:把 Q-tiles 做行分片,K/V-tiles 做列分片,引入2.5D 网格通信(环形+树形混合),并让第 4) 的半群归并跨卡前缀合一。
收益:在多卡(甚至多机)下延续 Flash 的 IO-aware 优势;长序列扩展能力更强。
代价/风险:通信拓扑与负载均衡复杂,要有拓扑感知调度器。
8、注意力调度器:分数-引导的 Tile 重排(精确派→轻近似)
痛点:默认顺序扫 tile 不是信息论最优。
思路:用极低成本的粗粒度打分(例如上界估计或低秩预热)先“猜”出高贡献的 K/V-tiles,优先算高分块,让归一化的尺度更早稳定,减小后续数值漂移与无效工作。
收益:更少的回溯与重标定,端到端时延下降。
代价/风险:需要保证重排不会破坏等价性(等价派需全量算,只是排序不同)。
示例:把 1、4、8 的思路串在一起
\初始状态:m 表示当前最大值,d 表示累积的指数和,out 是输出累加
state = (m=-inf, d=0, out=0)
\先做个粗打分,把最可能贡献大的 K-tile 放前面算(方向 8)
candidates = rank_tiles_by_upper_bound(Q, K_tiles)
for tile in candidates:
\方向 1:支持层级/可重排;方向 6:低比特输入 + FP32 累加
S = Q_tile @ K_tile.T / sqrt(dk)
方向 4:把每个 tile 的 log-sum-exp 状态拿出来
(m_t, d_t) = logsumexp_state(S)
state.(m,d) = semigroup_merge(state.(m,d), (m_t, d_t))
\ 做一次分段归一化
P = exp(S - state.m) / (state.d_partial?)
\ 输出累加;这里可以顺便筛除掉低贡献的计算
out += P @ V_tile
\ 最后一步:把所有 block 的 (m,d,out) 做一次全局 semigroup 归并
\ 得到和完整 Attention 一样的结果
什么时候选哪种组合?
训练/对齐阶段:优先 1/2/4/6/7(完全等价 & 可扩展)
超长上下文推理:1/2/5/7 必选,必要时叠加 3/8 做可控近似,换低延迟
边缘/移动端:5/6/8 组合,先把 IO 和精度能耗打下来
。。。。。
FlashAttention ,下一代可以做的是:
更聪明地分块(自适应/层级/重排)
更稳健地跨块合并(可并联的 log-sum-exp 半群)
更经济地存取(KV 在线压缩与异构精度)
更大规模地协同(2.5D 并行与拓扑感知)