大模型如何训练百万 Token 上下文:上下文并行与 Ring Attention

简介: 上下文窗口暴增至千万级,但硬件难承其重:405B模型单精度权重就需6.5TB内存。为突破显存瓶颈,上下文并行与Ring Attention应运而生——将长序列切分至多卡,边传边算;Zig-Zag分配更实现因果注意力下的负载均衡。高速互连(NVLink/InfiniBand)已成刚需。

只用了几年时间,上下文窗口就从 4k 膨胀到 1000 万。Meta 发布的 Llama 4 Scout 的时候说这个模型支持 1000 万 Token,是 Llama 3 那 128k 的 78 倍。而Google Gemini 3 Pro 是 100 万,Claude 4 也桐乡市100万。

一次推理跑完整个代码库、几百篇论文、连续好几天的对话记录在技术上可行了,但问题是硬件跟不上。

405B 参数的模型,32 位精度下光权重就要 6.5TB 内存。再算上梯度、状态、激活值,后者还随上下文长度二次方增长。单台 NVIDIA HGX B300 配了 2.3TB HBM3e都不够。

这就逼着必须做多节点分布式训练和推理,几十上百块 NVIDIA Blackwell GPU 、NVLink 再加上 InfiniBand,就成了数据中心的标配。所以难点就变味了 GPU 之间的通信瓶颈。

并行化基础

模型或数据集超出单卡容量,就得上并行策略,但是每种策略本质上都是拿通信开销换内存空间。

数据并行是最直接的方案:整个模型复制到每张卡上,训练数据切开,每张卡跑不同的 batch跑完一步同步梯度。适合小模型,计算是瓶颈、内存不是问题的场景。

模型并行针对大模型:单卡装不下,就把模型拆开,不同的层放不同的卡上,按顺序跑。405B 这种规模只能这样,并且下游的卡得等上游算完中间是有空转的。

张量并行更极端:连单个矩阵乘法都塞不进一张卡。就需要把矩阵按行或按列切开,分到各卡上算,再通过 all-reduce 合起来。

但这些都有共同的局限。模型大、上下文又长到几百万 Token,张量并行也顶不住。因为注意力的二次方内存增长太凶,激活值直接占满显存。128k 上下文的激活值内存是 8k 的 16 倍,这个目前没办法,因为就是这么夸张。

上下文并行与序列并行

序列并行和上下文并行都是在设备间切序列来省内存,但切法不一样。

序列并行配合张量并行使用,只切那些非矩阵乘法的操作,比如层归一化、dropout。张量并行管不到的地方,序列并行接手,每张卡处理一部分激活值。两者配合能把序列撑长一些,但到 128k 以上还是会有问题,因为注意力的二次方增长是绕不过去。

上下文并行更彻底:整个序列在所有模块里都切开,包括注意力。每个操作拿到的都是分区后的序列。百万级上下文的训练就靠这个,把激活值的内存占用分摊到各卡上。

注意力一直是最麻烦的问题,因为模型的其他操作基本都是逐 Token 独立处理并行起来很自然。但注意力不行,每个 Token 都要"看"序列里所有其他 Token。序列切到多张卡上之后,GPU 1 的 Token 怎么看 GPU 2 的 Token?直接等数据传完再算,整个流水线就卡住了。

Ring Attention 就是来解决这个问题的,让多节点多卡的大模型训练和推理能在大规模数据中心里跑起来。

Zig Zag Ring Attention:通信和计算重叠

Ring Attention 把 GPU 组织成环形拓扑。每张卡的工作流程是这样的:持有序列中 Q、K、V 张量的一个分块;用本地的 K 和 V 给自己的 Q 分块算注意力;把 K 和 V 传给环里的下一张卡;从上一张卡接收 K 和 V;循环往复,直到所有 Q Token 都跟所有 K/V Token 算完注意力。

关键在于计算和通信是重叠的。GPU 1 拿着当前的 K/V 分块算注意力的时候,同时在从 GPU 0 接收下一批分块。通信延迟减少了,因为不用干等数据全到了再开算。

GPT 这类自回归模型有个额外的麻烦:Token 只能看前面的 Token不能看后面的。所以会导致负载不均衡有些卡会空转,Zig-Zag Ring Attention 解决这个问题的办法是交错分配,不是按顺序切块而是 GPU 0 拿 Token [0, 4, 8...],GPU 1 拿 [1, 5, 9...],以此类推。每张卡都拿到早期和晚期 Token 的混合,因果注意力计算时负载就均衡了环里不会有卡闲着。

但是代价是索引逻辑稍微复杂一点,不过大规模场景下性能收益很可观,因果掩码下也能做到接近满 GPU 利用率。

上下文并行与 Ring Attention 常见问题

上下文并行把输入序列切到多张 GPU 上,突破训练时的内存限制。跟张量并行、数据并行不同,它在所有模型模块里都切序列维度。单卡装不下的百万级 Token 上下文,只有靠这个才能训。

Ring Attention 把 GPU 排成环,每张卡一边算当前数据的注意力,一边把键值对往下传。通信和计算重叠,全对全的注意力计算不用等完整序列数据到齐,GPU 不会干等。

而序列并行只切非矩阵乘法操作(层归一化之类的),配合张量并行用。上下文并行在所有模块里都切序列,包括注意力。超过 128k Token 的上下文必须用后者,因为激活值内存二次方增长太猛了。

为什么 Zig-Zag Ring Attention 比标准 Ring Attention 更好?

Zig-Zag 用交错分配代替顺序分配,因果掩码计算时各卡负载更均衡。标准 Ring Attention 会让后面的卡等前面的分块,造成计算空闲。Zig-Zag 把早期和晚期 Token 均匀撒到各卡上,避免这个问题。

那么训练百万级 Token 上下文的模型需要什么硬件?

多节点 GPU 集群,配 HBM 内存,加高速互连——NVIDIA NVLink 1.8TB/s 或者 InfiniBand。405B 参数模型 32 位精度从头训练加推理,4 台 NVIDIA HGX B300 的机架部署是个不错的起点。

总结

上下文并行本质上是拿通信开销换内存空间,而网络带宽是最要命的瓶颈。Ring Attention 要在 GPU 之间不停交换键值对,传输时间一旦超过计算时间,各卡就会从"边算边传"退化成"等数据"。NVIDIA NVLink 1.8TB/s 加 InfiniBand 的高速互连,在多机架部署里不是可选项是必需品。互连带宽必须匹配 GPU 计算吞吐量,否则上下文并行的效果会大打折扣。

https://avoid.overfit.cn/post/fd6022b9196942ffb737ba306925b6db

by Khang Pham

目录
相关文章
|
3月前
|
人工智能 安全 调度
AI工程vs传统工程 —「道法术」中的变与不变
本文从“道、法、术”三个层面对比AI工程与传统软件工程的异同,指出AI工程并非推倒重来,而是在传统工程坚实基础上,为应对大模型带来的不确定性(如概率性输出、幻觉、高延迟等)所进行的架构升级:在“道”上,从追求绝对正确转向管理概率预期;在“法”上,延续分层解耦、高可用等原则,但建模重心转向上下文工程与不确定性边界控制;在“术”上,融合传统工程基本功与AI新工具(如Context Engineering、轨迹可视化、多维评估体系),最终以确定性架构驾驭不确定性智能,实现可靠价值交付。
700 41
AI工程vs传统工程 —「道法术」中的变与不变
|
4月前
|
机器学习/深度学习 人工智能 PyTorch
深度解析 Google JAX 全栈:带你上手开发,从零构建神经网络
Google凭借JAX AI栈实现AI全栈垂直整合,覆盖模型、应用、云与硬件。JAX结合XLA编译器,Flax构建网络,Optax优化训练,Orbax管理 checkpoint,已在Google及Anthropic、Apple等广泛应用,助力高效大规模AI训练。
592 6
|
2月前
|
人工智能 弹性计算 安全
2026年阿里云OpenClaw一键快速部署教程,轻松搭建专属AI助理!
2026年,打造专属AI数字员工超简单:仅需一台阿里云服务器,几分钟用OpenClaw一键部署,接入百炼大模型,即可实现文档编写、资料检索、脚本运行、报表整理等智能办公能力。本地优先、安全可控、7×24在线。
598 5
|
3月前
|
机器学习/深度学习 算法 5G
MEDLL算法多径参数估计详解
MEDLL算法多径参数估计详解
151 6
|
3月前
|
人工智能 弹性计算 异构计算
2026年阿里云gpu云服务器活动参考:T4、V100、A10卡包月5折起,包年4折起
2026年阿里云推出GPU云服务器优惠活动,涵盖T4、V100、A10等多规格实例,活动截止到3月31日。活动对象为阿里云实名认证用户,新用户首购可享包年包月4折起、按量付费最长100小时1折起的优惠。具体型号与价格如gn7i-c32g1.8xlarge(A10卡)3213.99元/月起,gn6v-c8g1.2xlarge(V100卡)3830元/月起。活动支持包年包月与按量付费两种模式,满足AI训练、图形渲染等多场景需求,助力用户低成本开启AIGC之旅。
1631 2
|
存储 弹性计算 人工智能
阿里云文件存储NAS通用型、极速型和文件存储CPFS有什么区别?
阿里云文件存储NAS极速型NAS低时延,适合企业级时延敏感型核心业务;文件存储CPFS拥有高吞吐和高IOPS,适合高性能计算业务;通用型NAS大容量、高性价比、弹性扩展,支持低频介质,适合通用类文件共享业务。
2854 0
阿里云文件存储NAS通用型、极速型和文件存储CPFS有什么区别?
|
3月前
|
人工智能 安全 应用服务中间件
阿里云 Moltbot(原 Clawdbot)全套云服务介绍、部署步骤与使用指南
Moltbot(原Clawdbot)是由PSPDFKit Labs开发的开源自托管AI智能体(AI Agent),核心定位为“可自主执行任务的AI助手”,区别于传统问答式AI工具,其具备屏幕感知、任务规划、操作执行与状态验证的全链路能力,可7×24小时运行在服务器或终端设备上,通过自然语言指令自动完成文件管理、日程安排、邮件处理、代码编写、跨应用协同等自动化任务,数据优先存储于用户自有节点,隐私可控。2026年,阿里云正式上线Moltbot全套云服务,整合轻量应用服务器、无影云电脑、百炼大模型平台等核心资源,提供预置镜像、一键部署、安全优化等全流程支持,大幅降低部署门槛,适配个人、小型团队及企
4495 9
|
1月前
|
机器学习/深度学习 人工智能 机器人
大模型应用:稀疏注意力 vs 滑动窗口:大模型扩窗技术完全解析.58
本文详解大模型“扩窗”核心技术:滑动窗口注意力(快而局部,适合中短文本)与稀疏注意力(兼顾局部+跨步+首尾,支持超长上下文)。二者均通过降低O(n²)计算复杂度至线性,解决大模型长文本处理的内存与算力瓶颈,推动其从聊天工具升级为长文档分析、代码全量理解等实用AI。
512 26
|
3月前
|
自然语言处理 并行计算 PyTorch
用 PyTorch 实现 LLM-JEPA:不预测 token,预测嵌入
本文从零实现LLM-JEPA:将大语言模型与联合嵌入预测架构(JEPA)结合。通过span遮蔽构造context/target双视图,用可训练编码器预测目标编码器在遮蔽位置的归一化嵌入,以余弦距离为对齐损失,并通过EMA稳定训练。代码简洁清晰,逐行注释,助你深入理解JEPA核心思想。
214 6
用 PyTorch 实现 LLM-JEPA:不预测 token,预测嵌入