Mosaic:面向超长序列的多GPU注意力分片方案

简介: 本文剖析Transformer中“二次方注意力瓶颈”的成因与工程破解之道,聚焦Mosaic提出的多轴注意力分片方案。针对长序列内存爆炸问题,Mosaic通过灵活路由不同轴至本地或分布式后端(如Ring、Mesh2D),实现高效计算与通信平衡,尤其适用于表格等多维数据场景,显著降低显存占用且不侵入模型代码。

Transformer的"二次方注意力瓶颈"的问题是老生常谈了。这个瓶颈到底卡在哪实际工程里怎么绕过去?本文从一个具体问题出发,介绍Mosaic这套多轴注意力分片方案的设计思路。

注意力的内存困境

注意力机制的计算公式:

 Attention(Q, K, V) = softmax(QKᵀ / √d) × V

问题出在 QKᵀ 这个矩阵上,它的形状是

(序列长度 × 序列长度)

拿150,000个token的序列算一下:

 Memory = 150,000² × 4 bytes = 90 billion bytes ≈ 84 GB

这只是注意力权重本身的开销,而且还是单层、单头。A100的显存上限是80GB,放不下就是放不下。

现有方案的局限

FlashAttention 它通过分块计算,不需要把完整的注意力矩阵实例化出来,内存复杂度从O(n²)降到O(n)。单卡场景下效果很好,但问题是整个序列还是得塞进同一张GPU。

Ring Attention 换了个思路:把序列切片分到多张GPU上,每张卡持有一部分Q,K和V在GPU之间像传令牌一样轮转,一维序列处理起来是很不错的。

但是多维怎么办?

比如处理表格数据的Transformer,输入张量形状是

(batch, rows, features, embed)

。模型需要在不同维度上做注意力:features维度只有5个token,rows维度却有150,000个。前者单卡轻松搞定,后者则必须分片。

现有的库都没法干净地处理这种多轴场景。手写的话,每个轴要单独写分片逻辑,进程组管理、张量reshape全得自己来。代码会变得很脏。

Mosaic的设计

Mosaic本质上是个协调层,负责把不同的注意力轴路由到合适的计算后端:

 import mosaic

# Small axis: run locally
feature_attn = mosaic.MultiAxisAttention(  
    embed_dim=96,   
    num_heads=4,  
    attention_axis=2,    # features dimension
    backend="local"      # no communication needed
)

# Large axis: shard across GPUs
row_attn = mosaic.MultiAxisAttention(  
    embed_dim=96,   
    num_heads=4,  
    attention_axis=1,    # rows dimension
    backend="ring"       # ring attention across GPUs
 )

底层Mosaic会自动处理轴的置换、QKV投影前的reshape、后端分发、以及计算完成后张量形状的还原。模型代码保持清晰,分布式的复杂性被封装掉了。

Ring Attention的工作机制

核心思想其实很直接:不需要同时持有全部的K和V。可以分批计算注意力分数,逐步累积,最后再做归一化。

比如说4张GPU的情况下流程是这样的:

 Initial state:  
  GPU 0: Q₀, K₀, V₀  
  GPU 1: Q₁, K₁, V₁    
  GPU 2: Q₂, K₂, V₂  
  GPU 3: Q₃, K₃, V₃

Step 1: Each GPU computes attention with its local K, V  
  GPU 0: score₀₀ = Q₀ @ K₀ᵀ  
  ...

Step 2: Pass K, V to the next GPU in the ring  
  GPU 0 receives K₃, V₃ from GPU 3  
  GPU 0 sends K₀, V₀ to GPU 1  

Step 3: Compute attention with received K, V  
  GPU 0: score₀₃ = Q₀ @ K₃ᵀ  
  Accumulate with score₀₀

Repeat for all chunks...

 Final: Each GPU has complete attention output for its Q chunk

单卡内存占用变成O(n²/p),p是GPU数量。8张卡的话内存需求直接砍到1/8。150k序列从84GB降到约10GB每卡。

Mesh2D:更激进的分片

序列特别长的时候Ring Attention的线性分片可能还不够,这时候可以用Mesh2D把Q和K都切分了:

 4 GPUs arranged in 2×2 mesh:

          K₀    K₁  
       ┌──────┬──────┐  
  Q₀   │GPU 0 │GPU 1 │  
       ├──────┼──────┤  
  Q₁   │GPU 2 │GPU 3 │  
       └──────┴──────┘  

 Each GPU computes one tile of QKᵀ

内存复杂度降到O(n²/p²)。64张卡组成8×8网格时,每卡内存需求下降64倍。

 attn=mosaic.MultiAxisAttention(  
     embed_dim=128,   
     num_heads=8,  
     attention_axis=1,  
     backend="mesh2d",  
     mesh_shape=(8, 8)  
 )

感知集群拓扑的组合策略

在实际部署环境里,不同GPU之间的通信带宽差异很大。节点内GPU走NVLink能到900 GB/s,跨节点通过InfiniBand通常只有200 GB/s左右。

ComposedAttention

就是针对这种拓扑特征设计的:

 # 4 nodes × 8 GPUs = 32 total
 composed = mosaic.ComposedAttention(  
     mesh_shape=(4, 8),       # (nodes, gpus_per_node)
     head_parallel=True,      # Split heads across nodes (slow link)
     seq_parallel="ring"      # Ring within nodes (fast link)
 )

需要更精细控制的话,可以用

HierarchicalAttention

 hier = mosaic.HierarchicalAttention(  
     intra_node_size=8,  
     intra_node_strategy="local",   # Compute locally within node
     inter_node_strategy="ring"     # Ring between node leaders
 )

重通信走快链路轻通信才跨节点。

实现细节

整个库大约800行Python,核心代码如下:

 class MultiAxisAttention(nn.Module):  
    def forward(self, x):  
        # 1. Move attention axis to seq position
        x, inv_perm = self._permute_to_seq(x)  

        # 2. Flatten batch dims, project QKV
        x = x.view(-1, seq_len, embed_dim)  
        qkv = self.qkv_proj(x).view(batch, seq, 3, heads, head_dim)  
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  

        # 3. Dispatch to backend
        out = self._attn_fn(q, k, v)  # local, ring, or mesh2d

        # 4. Project output, restore shape
        out = self.out_proj(out.transpose(1, 2).reshape(...))  
         return out.permute(inv_perm)

后端封装了现有的成熟实现:

local

后端调用

F.scaled_dot_product_attention

(也就是FlashAttention),

ring

后端用ring-flash-attn库的

ring_flash_attn_func

mesh2d

是自定义的all-gather加SDPA,所有的底层都跑的是FlashAttention内核。

所有后端统一用FlashAttention的融合GEMM+softmax实现。后端函数在初始化时就绑定好,前向传播不做分支判断。张量操作尽量用

x.view()

而不是

x.reshape()

,保持内存连续性。集合通信的目标张量预分配好,避免

torch.cat

的开销。模块级别做导入不在每次前向传播时产生import开销。

快速上手

安装:

 pip install git+https://github.com/stprnvsh/mosaic.git

 # With ring attention support
 pip install flash-attn ring-flash-attn

单节点启动:

 torchrun --nproc_per_node=4 train.py

多节点的话:

 # Node 0
 torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \  
          --master_addr=192.168.1.100 --master_port=29500 train.py

 # Node 1
 torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \  
          --master_addr=192.168.1.100 --master_port=29500 train.py

训练脚本示例:

 import mosaic  
import torch.distributed as dist

dist.init_process_group("nccl")  
ctx = mosaic.init(sp_size=dist.get_world_size())

model = MyModel().to(ctx.device)

# Data is pre-sharded: each GPU has seq_total / world_size tokens
x_local = load_my_shard()  
 out = model(x_local)  # Communication handled by Mosaic

总结

最后,Mosaic不会自动并行化模型(这个用nnScaler),不管数据并行(PyTorch DDP/FSDP的事),也不处理模型分片(交给FSDP或Megatron)。

Mosaic专注于一件事:多轴注意力的分片路由,这套方案最初是给 nanoTabPFN 做的,一个表格数据Transformer。

这个模型要同时在rows(150k个)和features(5个)两个维度做注意力。标准Ring Attention对维度语义没有感知,它只认序列这个概念,分不清rows和features的区别。

所以Mosaic需求很明确:小轴本地算,大轴分布式算,轴的路由逻辑不能侵入模型代码,有兴趣的可以试试。

https://avoid.overfit.cn/post/791e0f30540e4d289a43d01d383e8ab2

作者:Pranav Sateesh

相关实践学习
在云上部署ChatGLM2-6B大模型(GPU版)
ChatGLM2-6B是由智谱AI及清华KEG实验室于2023年6月发布的中英双语对话开源大模型。通过本实验,可以学习如何配置AIGC开发环境,如何部署ChatGLM2-6B大模型。
目录
相关文章
|
4月前
|
人工智能 运维 安全
阿里云GPU服务器全解析:租赁价格、GPU卡规格及问题解答FAQ
阿里云GPU云服务器(EGS)依托高性能GPU芯片与神龙架构,提供弹性灵活、安全稳定的算力支持,广泛适用于AI训练推理、图形渲染、科学仿真等场景。支持多种计费模式与丰富GPU规格,兼顾成本与性能,并集成机密计算、自动运维、生态兼容等核心优势,助力企业高效构建高性能计算环境。
2366 1
|
5月前
|
API 数据库 Docker
向量搜索升级指南:FAISS 到 Qdrant 迁移方案与代码实现
FAISS 适合实验,但生产环境痛点诸多:无元数据支持、非服务化、难持久化。迁移到 Qdrant 后,实现开箱即用的向量数据库能力,支持混合搜索、过滤、持久化与高效 API,大幅提升系统稳定性与开发效率,真正打通从研究到生产的闭环。
326 6
向量搜索升级指南:FAISS 到 Qdrant 迁移方案与代码实现
|
4月前
|
前端开发 算法
深度研究Agent架构解析:4种Agent架构介绍及实用Prompt模板
本文系统梳理了深度搜索Agent的主流架构演进:从基础的Planner-Only,到引入评估反馈的双模块设计,再到支持层次化分解的递归式ROMA方案。重点解析了问题拆解与终止判断两大核心挑战,并提供了实用的Prompt模板与优化策略,为构建高效搜索Agent提供清晰路径。
2050 10
深度研究Agent架构解析:4种Agent架构介绍及实用Prompt模板
|
4月前
|
Linux 数据安全/隐私保护 iOS开发
openSUSE-Leap-15.0-DVD-x86_64离线安装步骤 附安装包
准备8GB以上U盘,下载openSUSE Leap 15.0镜像并用Rufus或dd工具写入。进BIOS设U盘启动,关闭Secure Boot。安装时选中文、跳过网络,使用自动分区,设置root密码和管理员用户,全程离线完成安装。
|
4月前
|
数据挖掘 数据库 索引
RAG检索模型选型:Bi-Encoder、Cross-Encoder、SPLADE与ColBERT的技术对比
本文解析RAG系统中Bi-Encoder、Cross-Encoder、SPLADE与ColBERT的核心机制,探讨如何平衡高召回与高精准。通过多阶段架构组合稀疏与稠密检索,实现高效准确的语义搜索。
415 3
RAG检索模型选型:Bi-Encoder、Cross-Encoder、SPLADE与ColBERT的技术对比
|
4月前
|
弹性计算 关系型数据库 数据库
阿里云服务器最新活动价格:新用户专享u2a实例和新老用户同享云服务器活动价格参考
阿里云最新活动价格参考,目前在阿里云的活动中,新用户专享的通用算力型u2a实例云服务器,例如2核4G配置新用户购买价格为504.60元/1年起。新老用户同享经济型e实例和AMD旗舰9代计算型c9a、通用型g9a、内存型r9a实例云服务器,例如经济型e实例2核4G配置活动价格为599.93元/1年起,计算型c9a实例4核8G活动价格为3459.05元/1年起。本文为大家整理汇总了2026年截至目前,阿里云新用户专享u2a实例云服务器活动价格和新老用户同享的经济型e及计算型c9a等实例规格云服务器的活动价格,以供大家参考。
|
2月前
|
机器学习/深度学习 人工智能 搜索推荐
学生课堂行为识别数据集(2000张高质量标注)| YOLO训练数据集 AI智慧教育
本数据集含2000张高质量课堂图像,YOLO格式标注6类学生行为(举手、阅读、写作、使用手机、低头、睡觉),覆盖真实教室场景,支持智慧教育中的专注度分析、教学评估与AI模型训练,开箱即用。
|
5月前
|
存储 编解码 数据库
大规模向量检索优化:Binary Quantization 让 RAG 系统内存占用降低 32 倍
本文介绍基于二值化量化的高效RAG系统,通过将float32嵌入压缩为1bit,实现32倍内存缩减。结合Milvus与Hamming距离检索,3600万向量查询仅需30ms。采用过采样与重排序策略,准确率可达95%以上,适合高维大规模场景。
268 0
大规模向量检索优化:Binary Quantization 让 RAG 系统内存占用降低 32 倍
|
4月前
|
存储 弹性计算 Linux
2026阿里云服务器新手选购与操作全指南
在数字化业务部署中,云服务器凭借弹性扩展、成本可控的优势,成为个人开发者和企业搭建网站、运行应用的核心选择。阿里云 ECS(弹性计算服务)作为国内主流云服务产品,地域节点丰富、配置灵活,但新手常因不熟悉购买路径与参数选择感到困惑。下面从购买前准备、核心购买方式、配置选择要点到控制台基础操作,梳理新手所需的全流程知识,助力高效完成上云部署。
|
4月前
【Azure App Service】App Service 遇见 not enough space on the disk
App Service应用提示“磁盘空间不足”时,可通过PowerShell脚本快速统计c:\home和c:\local目录下各文件夹大小,定位大文件并删除,释放空间。
151 9