TorchAcc:基于 TorchXLA 的分布式训练框架

本文涉及的产品
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: 阿里云研究员、阿里云人工智能平台 PAI 技术负责人--林伟在GTC 2024 大会 China AI Day 线上中文演讲专场上介绍了TorchAcc,这是一个基于 PyTorch/XLA 的大模型分布式训练框架。

演讲人:林伟,阿里云研究员,阿里云人工智能平台 PAI 技术负责人


本文旨在探讨阿里云 TorchAcc,这是一个基于 PyTorch/XLA 大模型分布式训练框架。


过去十年 AI 领域的显著进步,关键在于训练技术的革新和模型规模的快速攀升。尽管大模型展现了堪比人类的理解力,但其训练却对算力提出了极高的要求。唯有配备充足的计算资源,方能在海量数据上有效训练大模型,确保其在有限时间内实现优质收敛。

image.png

图片来源于 GTC 2024大会China AI Day 线上专场的演讲《TorchAcc:基于TorchXLA的分布式训练框架》


根据上图左侧图表显示,过去五年,大模型规模的增长态势尤为突出,平均每两年大小翻 15 倍;而对于 Transformer 为代表的语言模型以及多模态模型而言,其规模膨胀速度更加惊人,每隔两年以 750 倍剧增。对比之下,右侧图表揭示了一个明显的矛盾点:不论是单个 GPU 的计算能力抑或是 GPU 显存容量的发展速度,都无法跟上模型规模如此急剧的扩张步伐。这一现实状况直接催生了对分布式训练的迫切需求。分布式训练不再局限于以往单纯的数据并行模式,而是在此基础上,更加重视并采取模型并行策略,以弥补单个计算单元算力与存储提升速度相对于模型规模增长的滞后性。


在分布式训练实践中,开发人员普遍认同,构建模型并行的分布式训练系统相比数据并行更为复杂。数据并行从分布式角度来看,其逻辑相对直接和简洁,因为每个计算节点执行的任务本质上是对等且一致的。在这种情况下,只需在训练过程末尾插入 AllReduce 步骤,将各个工作节点(worker)独立计算出的梯度差异累加整合,然后求平均值,并将最终梯度结果广播至所有参与工作的节点,用以同步更新全局模型参数。


这类简单的分布式训练范式,确实呈现出类似单机计算的特点,主要涉及全局梯度同步的 AllReduce。然而步入大模型时代,由于模型规模过大,已无法容纳于单个 GPU 之内,我们就必须采用模型并行策略,其开发难度也就陡然上升了。


原因是,模型并行需要根据模型的规模和结构来决定如何恰当地“分割”模型,即将其分割为多个可以平衡计算负载的模块。在不同的分割策略下,模型在各个节点上算子的算法实现方式会发生变化,同时,不同分割方法还会引起节点间通信原语的差异,需要精心选择最优分割方案以及配套的通信原语。


在模型分割完成后,接下来的任务就是选用适合的通信原语,并精细地调度各个算子及其相关的通信操作,力求最大化计算与网络通信的重叠(overlap),以充分发挥底层计算资源的效率。正是由于存在多种可能的分割选项与调度决策,寻求最优模型并行策略的复杂性明显高于数据并行,对开发者的技巧和经验提出了更高的要求。


image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


本文将围绕四个核心方面展开。首个议题是如何在 TorchAcc 中实现多样化的并行策略,涵盖了常规的数据并行,以及当下备受关注的 FSDP(Fully Sharded Data Parallel,又称 ZeRO (Zero Redundancy Optimizer)) 。此外,还包括了模型并行的各种形态,诸如算子并行,即 Tensor Parallelism,以及流水线并行(Pipeline Parallelism)等。


TorchAcc 的一大亮点在于其能够自动探寻并有机整合各类并行策略,并为用户提供高度自动化的分布式策略配置方案;与此同时,为了满足高级开发者的定制化需求,TorchAcc 还提供了半自动化的控制接口,允许用户介入并调整自动探索并行策略的过程,从而在兼顾灵活性的同时,最大程度地提升训练效率和资源利用率。


通过上述方式,TorchAcc 有效地助力算法开发者将精力集中于模型自身的结构设计、训练方法的优化,以及追求模型收敛性能的提升上,而非花费精力在分布式训练的具体实现细节。TorchAcc 将智能化地协助开发者探寻并实现最佳的分布式训练方案,从而显著提升计算资源利用效率和算法迭代效率。


其次,模型并行技术的必要性是因为大模型尺寸超出单个 GPU 显存容量的限制。显存容量对于模型训练至关重要,如何打破显存瓶颈,对于提升分布式训练的整体效率来说至关重要。因此,TorchAcc 提供了一种显存智能分配器,通过对显存资源的精细化调度与地址分配策略,最大限度地提高模型并行训练时的效率,确保模型能充分利用现有的显存地址空间。


再者,随着模型结构日益复杂,且规模不断增大,用户对计算资源的需求也在持续攀升,因此,进一步优化模型在训练过程中的计算密集度及减少访存开销也非常关键。


最后,考虑到当前数据中心基础设施的发展趋势,大模型训练对网络条件的要求日渐严苛。现代数据中心服务器间的互联带宽已达到 TB 级别,以满足大规模模型并行训练对高速数据交换的需求。然而,模型并行所带来的复杂通信模式与高频次的数据交互亦会对整体训练效率构成挑战。因此,如何有效利用网络带宽,减少通信过程在迭代计算中占据的时间比例,也就成了训练效率提升的另一重要因素。


在具体实现上,TorchAcc 通过一系列技术手段,成功地将用户在前端,无论是基于 PyTorch 还是 TensorFlow 构建的模型训练过程转化为统一的中间表示层(Model IR)的 graph。其中,对于 TensorFlow 而言,因其自身就是一种计算图模型,转化过程相对直接,而对于 PyTorch,我们采用了符号式追踪(symbolic tracing)以及 LazyTensor 等技术捕获计算图,进而转化为 IR Graph。


image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


基于中间表示层(IR Graph)的构建,TorchAcc 实施了一系列多元化的优化策略,涵盖计算优化、存储优化、通信优化以及分布式策略优化,IR Graph 以各类组合并反复执行这些优化的 Pass 后,最终得到一个最优的执行 Plan。然后交由底层 Backend 执行,以实现模型训练性能的最大化提升。


通过这一整套方案,TorchAcc 在多个模型的分布式训练场景中表现出了显著的性能优势。部分模型的训练过程得以实现高达 3 倍的性能提速,充分证明了 TorchAcc 在解决分布式训练难题上的高效性和实用性。


image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


这张图片主要展示了 TorchAcc 的框架总体架构。TorchAcc 以 Pytorch/XLA 为基础,并 TorchAcc 依托于 OpenXLA,构建了一套大模型训练加速框架。TorchAcc 在处理使用不同前端构建的模型时,会灵活采用适宜的图捕获技术,如 Symbolic Trace 和 LazyTensor,进而生成两种不同层级的图表示:FX Graph 和 HLO Graph。其中,FX Graph 位于较高抽象层次,而 HLO Graph 则更为底层。


基于捕获到的模型计算图,TorchAcc 即可进一步展开了四类优化工作,即前文提及的计算优化、存储优化、通信优化以及分布式策略优化。

image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


在分布式策略优化层面,TorchAcc 支持业界广泛使用的各种并行策略,并能够灵活地结合这些策略对给定模型进行有效的并行化处理。具体而言,对于数据并行 DP(Data Parallelism)、流水并行 PP(Pipeline Parallelism)以及 FSDP(Fully Sharded Data Parallel, 也称为 ZeRO)这三种分布式策略,其实现和优化都是在 FX Graph 这一较高抽象层次上完成的。


选择在 FX Graph 层面对并行策略进行操作的原因在于,这一层级所包含的关于计算图结构和操作的信息已足够丰富,足以支撑开发人员设计出适应不同并行策略的优化方案。相较于在更低层的 HLO Graph 上直接进行优化,由于 FX Graph 具有更高的抽象性和概括性,在这一层面上进行优化的成本通常较低,更容易实施高效且针对性强的分布式策略调整。


以流水并行作为例子,系统能够自动检测 FX Graph 层级上的不同阶段,并确定合适的分割点,从而有效地将模型分割为多个连续执行的阶段,实现流水线并行化。在此过程中,我们可以利用 FX Graph 提供的详细计算结构信息来进行智能分割。


至于 Tensor Parallelism (张量并行)和 Sequence Parallelism (序列并行)这两种更为复杂的并行策略,它们要求更为细致精确的信息以便进行决策。为了实现这一点,系统需要对前向传播和反向传播的整个计算图的执行计划来进行分析。这时的工作主要在 HLO 这一低级别表示层面上进行。

通过利用 PyTorch/XLA 提供的 mark sharding 接口,系统能够在模型参数上添加相应的拆分标记,然后将这些拆分信息传递给 OpenXLA 的 SPMD 优化 Pass,进而触发计算图的拆分、优化、推导和重写过程,最终实现自动的 Tensor Parallelism 和 Sequence Parallelism 功能。


image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


在算子优化层面,TorchAcc 引入 FlashAttention 技术来提升 Attention 模块的执行效率。首先,通过 XLA 的 custom call 功能,将 FlashAttention 的实现无缝地融入到了 OpenXLA 编译器和运行时框架中。这意味着 FlashAttention 可以直接在 XLA 内核层级被执行,从而充分利用硬件加速能力。


在整合过程中,要处理好在 PyTorch 与 XLA 之间 Tensor 数据的传递问题,确保在两个系统间转换时的数据一致性与性能优化,同时,还要妥善处理 FlashAttention内部参数传递等细节问题,保证在并行计算和优化的过程中,这些关键参数能够正确且高效地应用到计算中,进一步提升模型在执行注意力机制部分的运算速度和资源利用率。


为了用户能便捷地使用 FlashAttention 优化功能,我们提供了两种接口,用户也可以直接通过 Python 接口调用预先写好的 FlashAttention 算子,第三种方法是用户可以使用我们在 OpenXLA 上写好的 Pattern Match Pass,该 Pass 能够自动识别计算图中的 Attention Block,并将这部分计算结构提取出来,替换为FlashAttention 的 custom call。这样设计的优势在于,既能充分利用 XLA 原本就十分出色的 Kernel fusion 等算子优化功能,又能结合 FlashAttention 带来的先进计算优化技术。


image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


在 Llama 2-7B 模型的性能测试中,我们能够明显观察到上述计算优化带来的效果。通过利用 XLA 自身的优化技术,尤其是 kernel fusion,我们将大量的访存密集型算子做了有效合并,从而大幅减少其数量,在叠加 FlashAttention 后,优化性能进一步提升。

image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


在通信优化层面,我们主要完成了三项核心任务以提升分布式训练效率:首先,我们合并了一些零散的 collective 通讯算子,通过减少算子数量来降低通讯开销和调度复杂度,其次,我们将合并的 collective 通讯算子移至独立的 CUDA Stream 上执行,这样一来,就能够异步实现计算与通讯的重叠执行。最后,我们充分利用了 OpenXLA 的 Latency Hiding Scheduler 功能,对通讯算子的调度进行了精细优化,使其尽早启动和执行,从而增强通讯与计算之间的重叠效果。


image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


通过在 Llama2 -7B 模型上进行的端到端多机性能测试,我们发现,应用了通讯优化策略后,在 128 张 GPU 卡上进行分布式训练,优化后的加速比从原来的 88 提升到了 116,通过 timeline 图我们也可以直观地看到,优化后的通讯算子更加有序,并且能够更好地和计算重叠执行。


image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


本文最后一个章节绍 TorchAcc 的显存优化功能,该功能通过优化计算图中算子的执行顺序以及 Tensor 在显存中的地址分配,来降低显存开销。

如图举例说明,假设有一个包含四个算子 V0、V1、V2、V3 的计算图,如果不控制算子执行顺序,如左图所示按照 V0-V1-V2-V3 的顺序执行,若每个 Tensor 按照默认方式进行显存地址申请,则可能出现如 B 图左半部分所示的情况,即显存容量不足以容纳所有 Tensor,导致 out of memory 错误。

然而,如果我们能够预判并精细管理内存分配,即在分配地址时预知后续执行的算子序列,即可如 B 图右半部分所示进行更优的显存布局,使得整体计算可在有限显存内顺利完成。更进一步,通过精确控制执行顺序,比如按照 V0-V2-V1-V3 的方式执行,可以进一步压缩显存需求至原始需求的 70% 左右。

这一理念是基于 XLA 中间表示层已有的 scheduler 和 buffer 管理机制,我们在此基础上提出了更先进的显存优化方法。目前业界存在多种优化显存分配的方法,如启发式算法、约束求解等,但这些方法往往难以兼顾时效性和高效性,在实际生产环境的集群中应用时可能存在局限性。

image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》


在训练场景中实现有效且高效的显存优化是一项极具挑战的任务,原因主要包括以下几个方面:

  1. NP-Hard 问题本质:由于模型的规模、算子的种类繁多,以及算子间显存分配的复杂性,显存优化问题成为一个典型的 NP-hard 问题,即找到全局最优解在计算上通常是不可行的。
  2. 算子执行灵活性:训练过程中,前向传播、反向传播和权重更新等操作具有很高的灵活性,特别是在权重更新方面,梯度产生后随时可以被用于权重更新,但不同的执行时机会影响显存的申请和释放,增加了优化难度。
  3. 显存复用复杂性:在训练过程中,前向和反向传播可以通过复用显存减少重新计算,但 Tensor 生命周期的多样性和尺寸的变化使得显存复用变得极为复杂,这对启发式算法等传统优化手段构成了严峻挑战。


为了解决上述难题,我们采取了一种分治策略:

  1. Memory-aware Weight Update Scheduler:引入了显存感知的权重更新调度器,它会根据梯度产生的时机、使用的优化器类型以及当前显存资源状况,选择合适的权重更新时间点,避免即时更新加重显存压力,特别是对于复杂的优化器如 Adam,需考虑动量和其他变量的存储。
  2. Graph 分割与局部优化:将大计算图根据关键节点 (memory insensitive operator) 分割成多个内存无关性的子图,子图间执行顺序固定,而子图内部的执行顺序则可以多样化。通过这种方式,可以将复杂的全局线性规划问题分解成多个局部问题,在子图范围内采用高效的优化方法,如线性规划求解最优执行顺序。

通过上述分治策略,最终我们能够聚合这些子图的求解结果,这也就是我们提出的 ROAM (Reorder Operators and Arrange Tensors Address to Reduce Memory Usage) 这一内存优化探索方式。

上述方法可以成功实现对显存优化问题的高效处理。实验结果显示,与原生 PyTorch、启发式算法以及 Facebook 近期基于整数线性规划的优化方法等 baseline 相比,ROAM 分别节省了约 16%、13% 和 27% 的显存开销,且在优化时长和可扩展性方面表现出色,证实了这种方法的有效性。

image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

image.png

图片来源于 GTC 2024 大会 China AI Day 线上专场的演讲《TorchAcc:基于 TorchXLA 的分布式训练框架》

从另一个维度衡量效果,我们考察了算法求解的时间开销。实验证明,在常见的深度学习场景中,我们的优化算法能够在短短几分钟内得出优化结果。从右图所示对比中可以看出,相较于 Facebook 最近提出的 MODeL(一种基于线性规划的优化方法),我们的方法在求解时间上实现了显著的缩减。原因在于,MODeL 在处理大规模图时并未对其进行有效分割,而我们的方法通过引入 memory-aware weight update scheduler 和子图划分策略,有效地降低了优化问题的空间复杂度,从而提高了求解效率。

综上所述,TorchAcc 在显存优化、计算优化、通信优化以及并行策略优化等方面均取得显著成效,全方位提升了分布式训练的效率与性能。


以上内容来源于 GTC 2024 大会 China AI Day 线上中文演讲专场。扫描图片二维码或登录大会官网,观看演讲视频,并可下载讲义。

image.png

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
2月前
|
Java 数据库
在Java中使用Seata框架实现分布式事务的详细步骤
通过以上步骤,利用 Seata 框架可以实现较为简单的分布式事务处理。在实际应用中,还需要根据具体业务需求进行更详细的配置和处理。同时,要注意处理各种异常情况,以确保分布式事务的正确执行。
|
2月前
|
消息中间件 Java Kafka
在Java中实现分布式事务的常用框架和方法
总之,选择合适的分布式事务框架和方法需要综合考虑业务需求、性能、复杂度等因素。不同的框架和方法都有其特点和适用场景,需要根据具体情况进行评估和选择。同时,随着技术的不断发展,分布式事务的解决方案也在不断更新和完善,以更好地满足业务的需求。你还可以进一步深入研究和了解这些框架和方法,以便在实际应用中更好地实现分布式事务管理。
|
10天前
|
存储 监控 数据可视化
常见的分布式定时任务调度框架
分布式定时任务调度框架用于在分布式系统中管理和调度定时任务,确保任务按预定时间和频率执行。其核心概念包括Job(任务)、Trigger(触发器)、Executor(执行器)和Scheduler(调度器)。这类框架应具备任务管理、任务监控、良好的可扩展性和高可用性等功能。常用的Java生态中的分布式任务调度框架有Quartz Scheduler、ElasticJob和XXL-JOB。
183 66
|
21天前
|
机器学习/深度学习 数据可视化 TensorFlow
使用Python实现深度学习模型的分布式训练
使用Python实现深度学习模型的分布式训练
164 73
|
3天前
|
数据采集 人工智能 分布式计算
MaxFrame:链接大数据与AI的高效分布式计算框架深度评测与实践!
阿里云推出的MaxFrame是链接大数据与AI的分布式Python计算框架,提供类似Pandas的操作接口和分布式处理能力。本文从部署、功能验证到实际场景全面评测MaxFrame,涵盖分布式Pandas操作、大语言模型数据预处理及企业级应用。结果显示,MaxFrame在处理大规模数据时性能显著提升,代码兼容性强,适合从数据清洗到训练数据生成的全链路场景...
15 5
MaxFrame:链接大数据与AI的高效分布式计算框架深度评测与实践!
|
6天前
|
人工智能 弹性计算 监控
分布式大模型训练的性能建模与调优
阿里云智能集团弹性计算高级技术专家林立翔分享了分布式大模型训练的性能建模与调优。内容涵盖四大方面:1) 大模型对AI基础设施的性能挑战,强调规模增大带来的显存和算力需求;2) 大模型训练的性能分析和建模,介绍TOP-DOWN和bottom-up方法论及工具;3) 基于建模分析的性能优化,通过案例展示显存预估和流水线失衡优化;4) 宣传阿里云AI基础设施,提供高效算力集群、网络及软件支持,助力大模型训练与推理。
|
17天前
|
分布式计算 大数据 数据处理
技术评测:MaxCompute MaxFrame——阿里云自研分布式计算框架的Python编程接口
随着大数据和人工智能技术的发展,数据处理的需求日益增长。阿里云推出的MaxCompute MaxFrame(简称“MaxFrame”)是一个专为Python开发者设计的分布式计算框架,它不仅支持Python编程接口,还能直接利用MaxCompute的云原生大数据计算资源和服务。本文将通过一系列最佳实践测评,探讨MaxFrame在分布式Pandas处理以及大语言模型数据处理场景中的表现,并分析其在实际工作中的应用潜力。
55 2
|
2月前
|
存储 Java 关系型数据库
在Spring Boot中整合Seata框架实现分布式事务
可以在 Spring Boot 中成功整合 Seata 框架,实现分布式事务的管理和处理。在实际应用中,还需要根据具体的业务需求和技术架构进行进一步的优化和调整。同时,要注意处理各种可能出现的问题,以保障分布式事务的顺利执行。
89 6
|
2月前
|
数据库
如何在Seata框架中配置分布式事务的隔离级别?
总的来说,配置分布式事务的隔离级别是实现分布式事务管理的重要环节之一,需要认真对待和仔细调整,以满足业务的需求和性能要求。你还可以进一步深入研究和实践 Seata 框架的配置和使用,以更好地应对各种分布式事务场景的挑战。
47 6
|
2月前
|
消息中间件 运维 数据库
Seata框架和其他分布式事务框架有什么区别
Seata框架和其他分布式事务框架有什么区别
35 1